From 75d832e36418fb020448d4f53f1fbd0cd4bc66f0 Mon Sep 17 00:00:00 2001 From: ahijevyc Date: Sat, 7 Sep 2024 11:09:00 -0600 Subject: [PATCH 1/5] Neighborhood filter --- uxarray/core/dataarray.py | 116 +++++++++++++++++++++++++++++++++++++- uxarray/core/dataset.py | 36 ++++++++++++ 2 files changed, 150 insertions(+), 2 deletions(-) diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index e976c5816..aa7c82b27 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union, Hashable, Literal +from uxarray.constants import GRID_DIMS from uxarray.formatting_html import array_repr from html import escape @@ -1044,8 +1045,6 @@ def isel(self, ignore_grid=False, *args, **kwargs): > uxda.subset(n_node=[1, 2, 3]) """ - from uxarray.constants import GRID_DIMS - if any(grid_dim in kwargs for grid_dim in GRID_DIMS) and not ignore_grid: # slicing a grid-dimension through Grid object @@ -1102,3 +1101,116 @@ def _slice_from_grid(self, sliced_grid): dims=self.dims, attrs=self.attrs, ) + + def neighborhood_filter( + self, + func: Callable = np.mean, + r: float = 1.0, + ) -> UxDataArray: + """Apply neighborhood filter + Parameters: + ----------- + func: Callable, default=np.mean + Apply this function to neighborhood + r : float, default=1. + Radius of neighborhood. For spherical coordinates, the radius is in units of degrees, + and for cartesian coordinates, the radius is in meters. + Returns: + -------- + destination_data : np.ndarray + Filtered data. + """ + + if self._face_centered(): + data_mapping = "face centers" + elif self._node_centered(): + data_mapping = "nodes" + elif self._edge_centered(): + data_mapping = "edge centers" + else: + raise ValueError( + f"Data_mapping is not face, node, or edge. Could not define data_mapping." + ) + + # reconstruct because the cached tree could be built from + # face centers, edge centers or nodes. + tree = self.uxgrid.get_ball_tree(coordinates=data_mapping, reconstruct=True) + + coordinate_system = tree.coordinate_system + + if coordinate_system == "spherical": + if data_mapping == "nodes": + lon, lat = ( + self.uxgrid.node_lon.values, + self.uxgrid.node_lat.values, + ) + elif data_mapping == "face centers": + lon, lat = ( + self.uxgrid.face_lon.values, + self.uxgrid.face_lat.values, + ) + elif data_mapping == "edge centers": + lon, lat = ( + self.uxgrid.edge_lon.values, + self.uxgrid.edge_lat.values, + ) + else: + raise ValueError( + f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', " + f"but received: {data_mapping}" + ) + + dest_coords = np.c_[lon, lat] + + elif coordinate_system == "cartesian": + if data_mapping == "nodes": + x, y, z = ( + self.uxgrid.node_x.values, + self.uxgrid.node_y.values, + self.uxgrid.node_z.values, + ) + elif data_mapping == "face centers": + x, y, z = ( + self.uxgrid.face_x.values, + self.uxgrid.face_y.values, + self.uxgrid.face_z.values, + ) + elif data_mapping == "edge centers": + x, y, z = ( + self.uxgrid.edge_x.values, + self.uxgrid.edge_y.values, + self.uxgrid.edge_z.values, + ) + else: + raise ValueError( + f"Invalid data_mapping. Expected 'nodes', 'edge centers', or 'face centers', " + f"but received: {data_mapping}" + ) + + dest_coords = np.c_[x, y, z] + + else: + raise ValueError( + f"Invalid coordinate_system. Expected either 'spherical' or 'cartesian', but received {coordinate_system}" + ) + + neighbor_indices = tree.query_radius(dest_coords, r=r) + + destination_data = np.empty(self.data.shape) + + # assert last dimension is a GRID dimension. + assert self.dims[-1] in GRID_DIMS, ( + f"expected last dimension of uxDataArray {self.data.dims[-1]} " + f"to be one of {GRID_DIMS}" + ) + # Apply function to indices on last axis. + for i, idx in enumerate(neighbor_indices): + if len(idx): + destination_data[..., i] = func(self.data[..., idx]) + + # construct data array for filtered variable + uxda_filter = self._copy() + + uxda_filter.data = destination_data + + return uxda_filter diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index 9a2f522a0..4f2786704 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -7,6 +7,7 @@ from typing import Optional, IO, Union +from uxarray.constants import GRID_DIMS from uxarray.grid import Grid from uxarray.core.dataarray import UxDataArray @@ -338,6 +339,41 @@ def to_array(self) -> UxDataArray: xarr = super().to_array() return UxDataArray(xarr, uxgrid=self.uxgrid) + def neighborhood_filter( + self, + func: Callable = np.mean, + r: float = 1.0, + ): + """Neighborhood function implementation for ``UxDataset``. + Parameters + --------- + func : Callable = np.mean + Apply this function to neighborhood + r : float, default=1. + Radius of neighborhood + """ + + + destination_uxds = self._copy() + # Loop through uxDataArrays in uxDataset + for var_name in self.data_vars: + uxda = self[var_name] + + # Skip if uxDataArray has no GRID dimension. + grid_dims = [dim for dim in uxda.dims if dim in GRID_DIMS] + if len(grid_dims) == 0: + continue + + # Put GRID dimension last for UxDataArray.neighborhood_filter. + remember_dim_order = uxda.dims + uxda = uxda.transpose(..., grid_dims[0]) + # Filter uxDataArray. + uxda = uxda.neighborhood_filter(func, r) + # Restore old dimension order. + destination_uxds[var_name] = uxda.transpose(*remember_dim_order) + + return destination_uxds + def nearest_neighbor_remap( self, destination_obj: Union[Grid, UxDataArray, UxDataset], From 8ec019328339dc92d719d7c28e87628de78de197 Mon Sep 17 00:00:00 2001 From: ahijevyc Date: Mon, 9 Sep 2024 10:48:09 -0600 Subject: [PATCH 2/5] ruff recommendations --- uxarray/core/dataarray.py | 8 ++++---- uxarray/core/dataset.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 09764c400..8cf3fddf2 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -1131,7 +1131,7 @@ def neighborhood_filter( data_mapping = "edge centers" else: raise ValueError( - f"Data_mapping is not face, node, or edge. Could not define data_mapping." + "Data_mapping is not face, node, or edge. Could not define data_mapping." ) # reconstruct because the cached tree could be built from @@ -1202,9 +1202,9 @@ def neighborhood_filter( # assert last dimension is a GRID dimension. assert self.dims[-1] in GRID_DIMS, ( - f"expected last dimension of uxDataArray {self.data.dims[-1]} " - f"to be one of {GRID_DIMS}" - ) + f"expected last dimension of uxDataArray {self.data.dims[-1]} " + f"to be one of {GRID_DIMS}" + ) # Apply function to indices on last axis. for i, idx in enumerate(neighbor_indices): if len(idx): diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index 4f2786704..dbf840903 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -353,7 +353,6 @@ def neighborhood_filter( Radius of neighborhood """ - destination_uxds = self._copy() # Loop through uxDataArrays in uxDataset for var_name in self.data_vars: From 5605949bfaa4bfc2a86c801f7103e34b6fdb765b Mon Sep 17 00:00:00 2001 From: ahijevyc Date: Mon, 9 Sep 2024 10:57:41 -0600 Subject: [PATCH 3/5] added Callable to Type checking --- uxarray/core/dataarray.py | 2 +- uxarray/core/dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 8cf3fddf2..a9076d7ab 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -4,7 +4,7 @@ import numpy as np -from typing import TYPE_CHECKING, Optional, Union, Hashable, Literal +from typing import TYPE_CHECKING, Callable, Optional, Union, Hashable, Literal from uxarray.constants import GRID_DIMS from uxarray.formatting_html import array_repr diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index dbf840903..f4c259297 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -5,7 +5,7 @@ import sys -from typing import Optional, IO, Union +from typing import Callable, Optional, IO, Union from uxarray.constants import GRID_DIMS from uxarray.grid import Grid From 0c7bc1eae7474f71c8b00a5a2060c5d3c08b2d6e Mon Sep 17 00:00:00 2001 From: ahijevyc Date: Mon, 9 Sep 2024 15:10:27 -0600 Subject: [PATCH 4/5] np.vstack().T faster than np.c --- uxarray/core/dataarray.py | 4 ++-- uxarray/core/dataset.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index a9076d7ab..8fae278a9 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -1162,7 +1162,7 @@ def neighborhood_filter( f"but received: {data_mapping}" ) - dest_coords = np.c_[lon, lat] + dest_coords = np.vstack((lon, lat)).T elif coordinate_system == "cartesian": if data_mapping == "nodes": @@ -1189,7 +1189,7 @@ def neighborhood_filter( f"but received: {data_mapping}" ) - dest_coords = np.c_[x, y, z] + dest_coords = np.vstack((x, y, z)).T else: raise ValueError( diff --git a/uxarray/core/dataset.py b/uxarray/core/dataset.py index f4c259297..2489c23ab 100644 --- a/uxarray/core/dataset.py +++ b/uxarray/core/dataset.py @@ -350,7 +350,8 @@ def neighborhood_filter( func : Callable = np.mean Apply this function to neighborhood r : float, default=1. - Radius of neighborhood + Radius of neighborhood. For spherical coordinates, the radius is in units of degrees, + and for cartesian coordinates, the radius is in meters. """ destination_uxds = self._copy() From d6d8a33faa8f64dec3fb930ef9ba53f25b63ff1d Mon Sep 17 00:00:00 2001 From: ahijevyc Date: Mon, 9 Sep 2024 15:25:26 -0600 Subject: [PATCH 5/5] Fix some comments --- uxarray/core/dataarray.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 8fae278a9..a0b61931c 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -1198,9 +1198,10 @@ def neighborhood_filter( neighbor_indices = tree.query_radius(dest_coords, r=r) + # Construct numpy array for filtered variable. destination_data = np.empty(self.data.shape) - # assert last dimension is a GRID dimension. + # Assert last dimension is a GRID dimension. assert self.dims[-1] in GRID_DIMS, ( f"expected last dimension of uxDataArray {self.data.dims[-1]} " f"to be one of {GRID_DIMS}" @@ -1210,7 +1211,7 @@ def neighborhood_filter( if len(idx): destination_data[..., i] = func(self.data[..., idx]) - # construct data array for filtered variable + # Construct UxDataArray for filtered variable. uxda_filter = self._copy() uxda_filter.data = destination_data