Skip to content

Commit

Permalink
ENH: support callable as stats in zonal_stats (#55)
Browse files Browse the repository at this point in the history
* ENH: support callable as stats in zonal_stats

* polygons -> geometry

* single thread to get proper coverage
  • Loading branch information
martinfleis authored Dec 15, 2023
1 parent 18ee579 commit 07a8e92
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 50 deletions.
57 changes: 31 additions & 26 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
from collections.abc import Hashable, Mapping, Sequence
from typing import Any
from typing import Any, Callable

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -921,10 +921,10 @@ def to_geodataframe(

def zonal_stats(
self,
polygons: Sequence[shapely.Geometry],
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str = "mean",
stats: str | Callable = "mean",
name: Hashable = "geometry",
index: bool = None,
method: str = "rasterize",
Expand All @@ -934,37 +934,43 @@ def zonal_stats(
):
"""Extract the values from a dataset indexed by a set of geometries
The CRS of the raster and that of polygons need to be equal.
The CRS of the raster and that of geometry need to be equal.
Xvec does not verify their equality.
Parameters
----------
polygons : Sequence[shapely.Geometry]
geometry : Sequence[shapely.Geometry]
An arrray-like (1-D) of shapely geometries, like a numpy array or
:class:`geopandas.GeoSeries`.
:class:`geopandas.GeoSeries`. Polygon and LineString geometry types are
supported.
x_coords : Hashable
name of the coordinates containing ``x`` coordinates (i.e. the first value
in the coordinate pair encoding the vertex of the polygon)
y_coords : Hashable
name of the coordinates containing ``y`` coordinates (i.e. the second value
in the coordinate pair encoding the vertex of the polygon)
stats : string
Spatial aggregation statistic method, by default "mean". It supports the
following statistcs: ['mean', 'median', 'min', 'max', 'sum']
stats : string | Callable
Spatial aggregation statistic method, by default "mean". Any of the
aggregations available as :class:`xarray.DataArray` or
:class:`xarray.DataArrayGroupBy` methods like
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`
are available. Alternatively, you can pass a ``Callable`` supported
by :meth:`~xarray.DataArray.reduce`.
name : Hashable, optional
Name of the dimension that will hold the ``polygons``, by default "geometry"
Name of the dimension that will hold the ``geometry``, by default "geometry"
index : bool, optional
If `polygons` is a GeoSeries, ``index=True`` will attach its index as another
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will attach its index as another
coordinate to the geometry dimension in the resulting object. If
``index=None``, the index will be stored if the `polygons.index` is a named
``index=None``, the index will be stored if the `geometry.index` is a named
or non-default index. If ``index=False``, it will never be stored. This is
useful as an attribute link between the resulting array and the GeoPandas
object from which the polygons are sourced.
object from which the geometry is sourced.
method : str, optional
The method of data extraction. The default is ``"rasterize"``, which uses
:func:`rasterio.features.rasterize` and is faster, but can lead to loss
of information in case of small polygons. Other option is ``"iterate"``, which
iterates over polygons and uses :func:`rasterio.features.geometry_mask`.
of information in case of small polygons or lines. Other option is ``"iterate"``, which
iterates over geometries and uses :func:`rasterio.features.geometry_mask`.
all_touched : bool, optional
If True, all pixels touched by geometries will be considered. If False, only
pixels whose center is within the polygon or that are selected by
Expand All @@ -975,22 +981,21 @@ def zonal_stats(
only if ``method="iterate"``.
**kwargs : optional
Keyword arguments to be passed to the aggregation function
(e.g., ``Dataset.mean(**kwargs)``).
(e.g., ``Dataset.quantile(**kwargs)``).
Returns
-------
Dataset
Dataset or DataArray
A subset of the original object with N-1 dimensions indexed by
the the GeometryIndex.
the :class:`GeometryIndex` of ``geometry``.
"""
# TODO: allow multiple stats at the same time (concat along a new axis),
# TODO: possibly as a list of tuples to include names?
# TODO: allow callable in stat (via .reduce())
if method == "rasterize":
result = _zonal_stats_rasterize(
self,
polygons=polygons,
geometry=geometry,
x_coords=x_coords,
y_coords=y_coords,
stats=stats,
Expand All @@ -1001,7 +1006,7 @@ def zonal_stats(
elif method == "iterate":
result = _zonal_stats_iterative(
self,
polygons=polygons,
geometry=geometry,
x_coords=x_coords,
y_coords=y_coords,
stats=stats,
Expand All @@ -1017,15 +1022,15 @@ def zonal_stats(
)

# save the index as a data variable
if isinstance(polygons, pd.Series):
if isinstance(geometry, pd.Series):
if index is None:
if polygons.index.name is not None or not polygons.index.equals(
pd.RangeIndex(0, len(polygons))
if geometry.index.name is not None or not geometry.index.equals(
pd.RangeIndex(0, len(geometry))
):
index = True
if index:
index_name = polygons.index.name if polygons.index.name else "index"
result = result.assign_coords({index_name: (name, polygons.index)})
index_name = geometry.index.name if geometry.index.name else "index"
result = result.assign_coords({index_name: (name, geometry.index)})

# standardize the shape - each method comes with a different one
return result.transpose(
Expand Down
26 changes: 26 additions & 0 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,3 +209,29 @@ def test_crs(method):

actual = da.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
xr.testing.assert_identical(actual, expected)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_callable(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
ds_agg = ds.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats=np.nanstd
)
ds_std = ds.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats="std"
)
xr.testing.assert_identical(ds_agg, ds_std)

da_agg = ds.z.xvec.zonal_stats(
world.geometry,
"longitude",
"latitude",
method=method,
stats=np.nanstd,
n_jobs=1,
)
da_std = ds.z.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats="std"
)
xr.testing.assert_identical(da_agg, da_std)
62 changes: 38 additions & 24 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import gc
from collections.abc import Hashable, Sequence
from typing import Callable

import numpy as np
import shapely
Expand All @@ -10,16 +11,16 @@

def _zonal_stats_rasterize(
acc,
polygons: Sequence[shapely.Geometry],
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str = "mean",
stats: str | Callable = "mean",
name: str = "geometry",
all_touched: bool = False,
**kwargs,
):
try:
import rasterio # noqa: F401
import rasterio
import rioxarray # noqa: F401
except ImportError as err:
raise ImportError(
Expand All @@ -28,15 +29,15 @@ def _zonal_stats_rasterize(
"'pip install rioxarray'."
) from err

if hasattr(polygons, "crs"):
crs = polygons.crs
if hasattr(geometry, "crs"):
crs = geometry.crs
else:
crs = None

transform = acc._obj.rio.transform()

labels = rasterio.features.rasterize(
zip(polygons, range(len(polygons))),
zip(geometry, range(len(geometry))),
out_shape=(
acc._obj[y_coords].shape[0],
acc._obj[x_coords].shape[0],
Expand All @@ -46,10 +47,13 @@ def _zonal_stats_rasterize(
all_touched=all_touched,
)
groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords)))
agg = getattr(groups, stats)(**kwargs)
if isinstance(stats, str):
agg = getattr(groups, stats)(**kwargs)
else:
agg = groups.reduce(stats, keep_attrs=True, **kwargs)
vec_cube = (
agg.reindex(group=range(len(polygons)))
.assign_coords(group=polygons)
agg.reindex(group=range(len(geometry)))
.assign_coords(group=geometry)
.rename(group=name)
).xvec.set_geom_indexes(name, crs=crs)

Expand All @@ -61,23 +65,23 @@ def _zonal_stats_rasterize(

def _zonal_stats_iterative(
acc,
polygons: Sequence[shapely.Geometry],
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str = "mean",
stats: str | Callable = "mean",
name: str = "geometry",
all_touched: bool = False,
n_jobs: int = -1,
**kwargs,
):
"""Extract the values from a dataset indexed by a set of geometries
The CRS of the raster and that of polygons need to be equal.
The CRS of the raster and that of geometry need to be equal.
Xvec does not verify their equality.
Parameters
----------
polygons : Sequence[shapely.Geometry]
geometry : Sequence[shapely.Geometry]
An arrray-like (1-D) of shapely geometries, like a numpy array or
:class:`geopandas.GeoSeries`.
x_coords : Hashable
Expand All @@ -87,10 +91,14 @@ def _zonal_stats_iterative(
name of the coordinates containing ``y`` coordinates (i.e. the second value
in the coordinate pair encoding the vertex of the polygon)
stats : Hashable
Spatial aggregation statistic method, by default "mean". It supports the
following statistcs: ['mean', 'median', 'min', 'max', 'sum']
Spatial aggregation statistic method, by default "mean". Any of the
aggregations available as DataArray or DataArrayGroupBy like
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`,
methods are available. Alternatively, you can pass a ``Callable`` supported
by :meth:`~xarray.DataArray.reduce`.
name : Hashable, optional
Name of the dimension that will hold the ``polygons``, by default "geometry"
Name of the dimension that will hold the ``geometry``, by default "geometry"
all_touched : bool, optional
If True, all pixels touched by geometries will be considered. If False, only
pixels whose center is within the polygon or that are selected by
Expand Down Expand Up @@ -140,14 +148,14 @@ def _zonal_stats_iterative(
all_touched=all_touched,
**kwargs,
)
for geom in polygons
for geom in geometry
)
if hasattr(polygons, "crs"):
crs = polygons.crs
if hasattr(geometry, "crs"):
crs = geometry.crs
else:
crs = None
vec_cube = xr.concat(
zonal, dim=xr.DataArray(polygons, name=name, dims=name)
zonal, dim=xr.DataArray(geometry, name=name, dims=name)
).xvec.set_geom_indexes(name, crs=crs)
gc.collect()

Expand All @@ -160,7 +168,7 @@ def _agg_geom(
trans,
x_coords: str = None,
y_coords: str = None,
stats: str = "mean",
stats: str | Callable = "mean",
all_touched=False,
**kwargs,
):
Expand Down Expand Up @@ -207,9 +215,15 @@ def _agg_geom(
invert=True,
all_touched=all_touched,
)
result = getattr(
acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords))), stats
)(dim=(y_coords, x_coords), keep_attrs=True, **kwargs)
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
if isinstance(stats, str):
result = getattr(masked, stats)(
dim=(y_coords, x_coords), keep_attrs=True, **kwargs
)
else:
result = masked.reduce(
stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs
)

del mask
gc.collect()
Expand Down

0 comments on commit 07a8e92

Please sign in to comment.