From 2ef6751d916cd94efeb6b877a67d24f0b8add700 Mon Sep 17 00:00:00 2001 From: Philip Chmielowiec Date: Wed, 9 Oct 2024 10:48:44 -0500 Subject: [PATCH] investigate where failure --- uxarray/core/dataarray.py | 40 +++++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index aa4e9a7b4..537f3228c 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -78,6 +78,17 @@ def __init__(self, *args, uxgrid: Grid = None, **kwargs): super().__init__(*args, **kwargs) + # TODO: + def __array_wrap__(self, obj, context=None): + return UxDataArray(obj, uxgrid=self.uxgrid) + + # TODO: + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + results = super().__array_ufunc__(ufunc, method, *inputs, **kwargs) + if isinstance(results, xr.DataArray): + return UxDataArray(results, uxgrid=self.uxgrid) + return results + # declare various accessors plot = UncachedAccessor(UxDataArrayPlotAccessor) subset = UncachedAccessor(DataArraySubsetAccessor) @@ -109,17 +120,23 @@ def _copy(self, **kwargs): return copied + # def _replace(self, *args, **kwargs): + # """Override to make the result a complete instance of + # ``uxarray.UxDataArray``.""" + # da = super()._replace(*args, **kwargs) + # + # if isinstance(da, UxDataArray): + # da.uxgrid = self.uxgrid + # else: + # da = UxDataArray(da, uxgrid=self.uxgrid) + # + # return da + def _replace(self, *args, **kwargs): """Override to make the result a complete instance of ``uxarray.UxDataArray``.""" - da = super()._replace(*args, **kwargs) - - if isinstance(da, UxDataArray): - da.uxgrid = self.uxgrid - else: - da = UxDataArray(da, uxgrid=self.uxgrid) - - return da + result = super()._replace(*args, **kwargs) + return UxDataArray(result, uxgrid=self.uxgrid) @property def uxgrid(self): @@ -138,6 +155,11 @@ def uxgrid(self): def uxgrid(self, ugrid_obj): self._uxgrid = ugrid_obj + # TODO: + def copy(self, *args, **kwargs): + result = super().copy(*args, **kwargs) + return UxDataArray(result, uxgrid=self.uxgrid) + def to_geodataframe( self, periodic_elements: Optional[str] = "exclude", @@ -1121,3 +1143,5 @@ def _slice_from_grid(self, sliced_grid): dims=self.dims, attrs=self.attrs, ) + + # def where(self):