Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: Unstructured to Structured Nearest Neighbor Remapping #892

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,308 changes: 2,308 additions & 0 deletions docs/user-guide/remapping-u2s.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
"cell_type": "markdown",
"id": "d9d3f5a8-6d3c-4a7e-9150-a2915f3e0ceb",
"metadata": {},
"source": [
"# Remapping"
]
"source": "# Remapping (Unstructured to Unstructured)"
},
{
"cell_type": "markdown",
Expand All @@ -15,8 +13,9 @@
"source": [
"Remapping, or commonly referred to as Regridding, is the process of taking data that resides on one grid and mapping it to another. Details on various remapping methods can be found [here](https://climatedataguide.ucar.edu/climate-tools/regridding-overview). This user guide section will cover the two native remapping methods that are supported by UXarray:\n",
"\n",
"* Nearest Neighbor\n",
"* Inverse Distance Weighted"
"* Nearest Neighbor \n",
"* Inverse Distance Weighted \n",
"\n"
]
},
{
Expand Down Expand Up @@ -195,6 +194,12 @@
"uxds_120[\"bottomDepth\"].remap"
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Unstructured to Unstructured Remapping",
"id": "bc33fa75540a80b0"
},
{
"cell_type": "markdown",
"id": "0e7f1fc3-090f-4aa4-86b3-de240dc68909",
Expand Down
8 changes: 6 additions & 2 deletions docs/userguide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ These user guides provide detailed explanations of the core functionality in UXa
`Subsetting <user-guide/subset.ipynb>`_
Select specific regions of a grid

`Remapping <user-guide/remapping.ipynb>`_
Remap (a.k.a Regrid) between unstructured grids
`Remapping (Unstructured to Unstructured) <user-guide/remapping-u2u.ipynb>`_
Remap/Regrid between unstructured grids

`Remapping (Unstructured to Structured) <user-guide/remapping-u2s.ipynb>`_
Remap/Regrid from unstructured to structured grids

`Topological Aggregations <user-guide/topological-aggregations.ipynb>`_
Aggregate data across grid dimensions
Expand All @@ -58,6 +61,7 @@ These user guides provide detailed explanations of the core functionality in UXa
`Face Area Calculations <user-guide/area_calc.ipynb>`_
Methods for computing the area of each face


Supplementary Guides
--------------------

Expand Down
12 changes: 12 additions & 0 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,18 @@ def uxgrid(self):
def uxgrid(self, ugrid_obj):
self._uxgrid = ugrid_obj

def to_structured(
self, lon: np.ndarray, lat: np.ndarray, method: str = "nearest neighbor"
):
# add checks for (-180, 180), (-90, 90)

destination_grid = xr.Dataset(coords=dict(lon=("lon", lon), lat=("lat", lat)))

if method == "nearest neighbor":
return self.remap.nearest_neighbor(destination_grid)
else:
raise ValueError("TODO")

def to_geodataframe(self, override=False, cache=True, exclude_antimeridian=False):
"""Constructs a ``spatialpandas.GeoDataFrame`` with a "geometry"
column, containing a collection of Shapely Polygons or MultiPolygons
Expand Down
12 changes: 7 additions & 5 deletions uxarray/remap/dataarray_accessor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
from warnings import warn

from uxarray.remap.nearest_neighbor import _nearest_neighbor_uxda
Expand All @@ -10,6 +10,7 @@
if TYPE_CHECKING:
from uxarray.core.dataset import UxDataset
from uxarray.core.dataarray import UxDataArray
from xarray import Dataset

from uxarray.grid import Grid

Expand All @@ -31,10 +32,11 @@ def __repr__(self):

def nearest_neighbor(
self,
destination_grid: Optional[Grid] = None,
destination_obj: Optional[Grid, UxDataArray, UxDataset] = None,
destination_grid: Optional[Grid, Dataset] = None,
destination_obj: Optional[Grid, UxDataArray, UxDataset, Dataset] = None,
remap_to: str = "face centers",
coord_type: str = "spherical",
coord_names: Union[tuple, list] = ("lon", "lat"),
):
"""Nearest Neighbor Remapping between a source (``UxDataArray``) and
destination.`.
Expand All @@ -60,15 +62,15 @@ def nearest_neighbor(

if destination_grid is not None:
return _nearest_neighbor_uxda(
self.uxda, destination_grid, remap_to, coord_type
self.uxda, destination_grid, remap_to, coord_type, coord_names
)
elif destination_obj is not None:
warn(
"destination_obj will be deprecated in a future release. Please use destination_grid instead.",
DeprecationWarning,
)
return _nearest_neighbor_uxda(
self.uxda, destination_obj, remap_to, coord_type
self.uxda, destination_obj, remap_to, coord_type, coord_names
)

def inverse_distance_weighted(
Expand Down
230 changes: 168 additions & 62 deletions uxarray/remap/nearest_neighbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
if TYPE_CHECKING:
from uxarray.core.dataset import UxDataset
from uxarray.core.dataarray import UxDataArray
from xarray import Dataset

import numpy as np
import xarray as xr


import uxarray.core.dataarray
import uxarray.core.dataset
Expand Down Expand Up @@ -148,11 +151,94 @@ def _nearest_neighbor(
return destination_data


def _nearest_neighbor_u2s(
source_grid: Grid,
destination_grid: Dataset,
source_data: UxDataArray,
coord_type: str = "spherical",
coord_names: Union[tuple, list] = ("lon", "lat"),
):
if any(coord_name not in destination_grid for coord_name in coord_names):
raise ValueError("TODO:")

n_elements = source_data.shape[-1]

if n_elements == source_grid.n_node:
source_data_mapping = "nodes"
elif n_elements == source_grid.n_edge:
source_data_mapping = "edge centers"
elif n_elements == source_grid.n_face:
source_data_mapping = "face centers"
else:
raise ValueError(
f"Invalid source_data shape. The final dimension should be either match the number of corner "
f"nodes ({source_grid.n_node}), edge centers ({source_grid.n_edge}), or face centers ({source_grid.n_face}) in the"
f" source grid, but received: {source_data.shape}"
)

if coord_type == "spherical":
_source_tree = source_grid.get_ball_tree(coordinates=source_data_mapping)

lon_grid, lat_grid = np.meshgrid(
destination_grid[coord_names[0]].values,
destination_grid[coord_names[1]].values,
)

grid_shape = lon_grid.shape

coords = np.vstack([lon_grid.ravel(), lat_grid.ravel()]).T

elif coord_type == "cartesian":
_source_tree = source_grid.get_ball_tree(
coordinates=source_data_mapping,
coordinate_system="cartesian",
distance_metric="minkowski",
)

x_grid, y_grid, z_grid = np.meshgrid(
destination_grid[coord_names[0]].values,
destination_grid[coord_names[1]].values,
destination_grid[coord_names[2]].values,
)

grid_shape = x_grid.shape

coords = np.vstack([x_grid.ravel(), y_grid.ravel(), z_grid.ravel()]).T

else:
raise ValueError("TODO: ")

# get nearest neighbor indices
_, nearest_neighbor_indices = _source_tree.query(coords, k=1)

# data values from source data to destination data using nearest neighbor indices
if nearest_neighbor_indices.ndim > 1:
nearest_neighbor_indices = nearest_neighbor_indices.squeeze()

# support arbitrary dimension data using Ellipsis "..."
destination_data = source_data.data[..., nearest_neighbor_indices]

if source_data.ndim == 1:
# case for 1D slice of data
destination_data = destination_data.reshape(grid_shape)
else:
destination_data = destination_data.reshape((-1) + grid_shape)

remapped_var = xr.DataArray(
data=destination_data,
dims=source_data.dims[:-1] + coord_names[::-1],
name=source_data.name,
)

return remapped_var


def _nearest_neighbor_uxda(
source_uxda: UxDataArray,
destination_obj: Union[Grid, UxDataArray, UxDataset],
destination_obj: Union[Grid, UxDataArray, UxDataset, Dataset],
remap_to: str = "face centers",
coord_type: str = "spherical",
coord_names: Union[tuple, list] = ("lon", "lat"),
):
"""Nearest Neighbor Remapping implementation for ``UxDataArray``.

Expand All @@ -167,63 +253,72 @@ def _nearest_neighbor_uxda(
coord_type : str, default="spherical"
Indicates whether to remap using on Spherical or Cartesian coordinates for nearest neighbor computations when
remapping.
coord_names: str
TODO
"""

# prepare dimensions
if remap_to == "nodes":
destination_dim = "n_node"
elif remap_to == "edge centers":
destination_dim = "n_edge"
else:
destination_dim = "n_face"

destination_dims = list(source_uxda.dims)
destination_dims[-1] = destination_dim

if isinstance(destination_obj, Grid):
destination_grid = destination_obj
elif isinstance(
destination_obj,
(uxarray.core.dataarray.UxDataArray, uxarray.core.dataset.UxDataset),
):
destination_grid = destination_obj.uxgrid
else:
raise ValueError("TODO: Invalid Input")
if not isinstance(destination_obj, xr.Dataset):
# Unstructured to Unstructured Case
if remap_to == "nodes":
destination_dim = "n_node"
elif remap_to == "edge centers":
destination_dim = "n_edge"
else:
destination_dim = "n_face"

destination_dims = list(source_uxda.dims)
destination_dims[-1] = destination_dim

if isinstance(destination_obj, Grid):
destination_grid = destination_obj
elif isinstance(
destination_obj,
(uxarray.core.dataarray.UxDataArray, uxarray.core.dataset.UxDataset),
):
destination_grid = destination_obj.uxgrid
else:
raise ValueError("TODO: Invalid Input")

# perform remapping
destination_data = _nearest_neighbor(
source_uxda.uxgrid, destination_grid, source_uxda.data, remap_to, coord_type
)
# construct data array for remapping variable
uxda_remap = uxarray.core.dataarray.UxDataArray(
data=destination_data,
name=source_uxda.name,
coords=source_uxda.coords,
dims=destination_dims,
uxgrid=destination_grid,
)
# add remapped variable to existing UxDataset
if isinstance(destination_obj, uxarray.core.dataset.UxDataset):
uxds = destination_obj.copy()
uxds[source_uxda.name] = uxda_remap
return uxds

# construct a UxDataset from remapped variable and existing variable
elif isinstance(destination_obj, uxarray.core.dataset.UxDataArray):
uxds = destination_obj.copy().to_dataset()
uxds[source_uxda.name] = uxda_remap
return uxds

# return UxDataArray with remapped variable
# perform remapping
destination_data = _nearest_neighbor(
source_uxda.uxgrid, destination_grid, source_uxda.data, remap_to, coord_type
)
# construct data array for remapping variable
uxda_remap = uxarray.core.dataarray.UxDataArray(
data=destination_data,
name=source_uxda.name,
coords=source_uxda.coords,
dims=destination_dims,
uxgrid=destination_grid,
)
# add remapped variable to existing UxDataset
if isinstance(destination_obj, uxarray.core.dataset.UxDataset):
uxds = destination_obj.copy()
uxds[source_uxda.name] = uxda_remap
return uxds

# construct a UxDataset from remapped variable and existing variable
elif isinstance(destination_obj, uxarray.core.dataset.UxDataArray):
uxds = destination_obj.copy().to_dataset()
uxds[source_uxda.name] = uxda_remap
return uxds

# return UxDataArray with remapped variable
else:
return uxda_remap
else:
return uxda_remap
remapped_var = _nearest_neighbor_u2s(
source_uxda.uxgrid, destination_obj, source_uxda, coord_type, coord_names
)
return remapped_var


def _nearest_neighbor_uxds(
source_uxds: UxDataset,
destination_obj: Union[Grid, UxDataArray, UxDataset],
destination_obj: Union[Grid, UxDataArray, UxDataset, Dataset],
remap_to: str = "face centers",
coord_type: str = "spherical",
coord_names: Union[tuple, list] = ("lon", "lat"),
):
"""Nearest Neighbor Remapping implementation for ``UxDataset``.

Expand All @@ -238,19 +333,30 @@ def _nearest_neighbor_uxds(
coord_type : str, default="spherical"
Indicates whether to remap using on Spherical or Cartesian coordinates
"""
if not isinstance(destination_obj, xr.Dataset):
if isinstance(destination_obj, Grid):
destination_uxds = uxarray.core.dataset.UxDataset(uxgrid=destination_obj)
elif isinstance(destination_obj, uxarray.core.dataset.UxDataArray):
destination_uxds = destination_obj.to_dataset()
elif isinstance(destination_obj, uxarray.core.dataset.UxDataset):
destination_uxds = destination_obj
else:
raise ValueError

if isinstance(destination_obj, Grid):
destination_uxds = uxarray.core.dataset.UxDataset(uxgrid=destination_obj)
elif isinstance(destination_obj, uxarray.core.dataset.UxDataArray):
destination_uxds = destination_obj.to_dataset()
elif isinstance(destination_obj, uxarray.core.dataset.UxDataset):
destination_uxds = destination_obj
else:
raise ValueError

for var_name in source_uxds.data_vars:
destination_uxds = _nearest_neighbor_uxda(
source_uxds[var_name], destination_uxds, remap_to, coord_type
)
for var_name in source_uxds.data_vars:
# TODO: bug here ?
destination_uxds = _nearest_neighbor_uxda(
source_uxds[var_name], destination_uxds, remap_to, coord_type
)

return destination_uxds
return destination_uxds
else:
for var_name in source_uxds.data_vars:
destination_obj[var_name] = _nearest_neighbor_uxda(
source_uxds[var_name],
destination_obj,
remap_to,
coord_type,
coord_names,
)
return destination_obj
Loading