Skip to content

Commit

Permalink
ENH: multiple aggregations at once in zonal_stats (#56)
Browse files Browse the repository at this point in the history
* ENH: multiple aggregations at once in zonal_stats

* expand doctring

* fix typing
  • Loading branch information
martinfleis authored Dec 15, 2023
1 parent 07a8e92 commit e86d3ec
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 33 deletions.
41 changes: 20 additions & 21 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def zonal_stats(
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: Hashable = "geometry",
index: bool = None,
method: str = "rasterize",
Expand All @@ -949,36 +949,39 @@ def zonal_stats(
y_coords : Hashable
name of the coordinates containing ``y`` coordinates (i.e. the second value
in the coordinate pair encoding the vertex of the polygon)
stats : string | Callable
stats : string | Callable | Sequence[str | Callable | tuple]
Spatial aggregation statistic method, by default "mean". Any of the
aggregations available as :class:`xarray.DataArray` or
:class:`xarray.DataArrayGroupBy` methods like
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`
are available. Alternatively, you can pass a ``Callable`` supported
by :meth:`~xarray.DataArray.reduce`.
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile` are
available. Alternatively, you can pass a ``Callable`` supported by
:meth:`~xarray.DataArray.reduce` or a list with ``strings``, ``callables``
or ``tuples`` in a ``(name, func, {kwargs})`` format, where ``func`` can be
a string or a callable.
name : Hashable, optional
Name of the dimension that will hold the ``geometry``, by default "geometry"
index : bool, optional
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will attach its index as another
coordinate to the geometry dimension in the resulting object. If
``index=None``, the index will be stored if the `geometry.index` is a named
or non-default index. If ``index=False``, it will never be stored. This is
useful as an attribute link between the resulting array and the GeoPandas
object from which the geometry is sourced.
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will
attach its index as another coordinate to the geometry dimension in the
resulting object. If ``index=None``, the index will be stored if the
`geometry.index` is a named or non-default index. If ``index=False``, it
will never be stored. This is useful as an attribute link between the
resulting array and the GeoPandas object from which the geometry is sourced.
method : str, optional
The method of data extraction. The default is ``"rasterize"``, which uses
:func:`rasterio.features.rasterize` and is faster, but can lead to loss
of information in case of small polygons or lines. Other option is ``"iterate"``, which
iterates over geometries and uses :func:`rasterio.features.geometry_mask`.
:func:`rasterio.features.rasterize` and is faster, but can lead to loss of
information in case of small polygons or lines. Other option is
``"iterate"``, which iterates over geometries and uses
:func:`rasterio.features.geometry_mask`.
all_touched : bool, optional
If True, all pixels touched by geometries will be considered. If False, only
pixels whose center is within the polygon or that are selected by
Bresenham’s line algorithm will be considered.
n_jobs : int, optional
Number of parallel threads to use. It is recommended to set this to the
number of physical cores of the CPU. ``-1`` uses all available cores. Applies
only if ``method="iterate"``.
number of physical cores of the CPU. ``-1`` uses all available cores.
Applies only if ``method="iterate"``.
**kwargs : optional
Keyword arguments to be passed to the aggregation function
(e.g., ``Dataset.quantile(**kwargs)``).
Expand All @@ -990,8 +993,6 @@ def zonal_stats(
the :class:`GeometryIndex` of ``geometry``.
"""
# TODO: allow multiple stats at the same time (concat along a new axis),
# TODO: possibly as a list of tuples to include names?
if method == "rasterize":
result = _zonal_stats_rasterize(
self,
Expand Down Expand Up @@ -1033,9 +1034,7 @@ def zonal_stats(
result = result.assign_coords({index_name: (name, geometry.index)})

# standardize the shape - each method comes with a different one
return result.transpose(
name, *tuple(d for d in self._obj.dims if d not in [x_coords, y_coords])
)
return result.transpose(name, ...)

def extract_points(
self,
Expand Down
61 changes: 61 additions & 0 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,64 @@ def test_callable(method):
world.geometry, "longitude", "latitude", method=method, stats="std"
)
xr.testing.assert_identical(da_agg, da_std)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_multiple(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
result = ds.xvec.zonal_stats(
world.geometry[:10].boundary,
"longitude",
"latitude",
stats=[
"mean",
"sum",
("quantile", "quantile", {"q": [0.1, 0.2, 0.3]}),
("numpymean", np.nanmean),
np.nanmean,
],
method=method,
n_jobs=1,
)
assert sorted(result.dims) == sorted(
[
"level",
"zonal_statistics",
"geometry",
"month",
"quantile",
]
)

assert (
result.zonal_statistics == ["mean", "sum", "quantile", "numpymean", "nanmean"]
).all()


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_invalid(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
with pytest.raises(ValueError, match=r"\['gorilla'\] is not a valid aggregation."):
ds.xvec.zonal_stats(
world.geometry[:10].boundary,
"longitude",
"latitude",
stats=[
"mean",
["gorilla"],
],
method=method,
n_jobs=1,
)

with pytest.raises(ValueError, match="3 is not a valid aggregation."):
ds.xvec.zonal_stats(
world.geometry[:10].boundary,
"longitude",
"latitude",
stats=3,
method=method,
n_jobs=1,
)
77 changes: 65 additions & 12 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,31 @@
from typing import Callable

import numpy as np
import pandas as pd
import shapely
import xarray as xr


def _agg_rasterize(groups, stats, **kwargs):
if isinstance(stats, str):
return getattr(groups, stats)(**kwargs)
return groups.reduce(stats, keep_attrs=True, **kwargs)


def _agg_iterate(masked, stats, x_coords, y_coords, **kwargs):
if isinstance(stats, str):
return getattr(masked, stats)(
dim=(y_coords, x_coords), keep_attrs=True, **kwargs
)
return masked.reduce(stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs)


def _zonal_stats_rasterize(
acc,
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: str = "geometry",
all_touched: bool = False,
**kwargs,
Expand Down Expand Up @@ -47,10 +62,31 @@ def _zonal_stats_rasterize(
all_touched=all_touched,
)
groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords)))
if isinstance(stats, str):
agg = getattr(groups, stats)(**kwargs)

if pd.api.types.is_list_like(stats):
agg = {}
for stat in stats:
if isinstance(stat, str):
agg[stat] = _agg_rasterize(groups, stat, **kwargs)
elif callable(stat):
agg[stat.__name__] = _agg_rasterize(groups, stat, **kwargs)
elif isinstance(stat, tuple):
kws = stat[2] if len(stat) == 3 else {}
agg[stat[0]] = _agg_rasterize(groups, stat[1], **kws)
else:
raise ValueError(f"{stat} is not a valid aggregation.")

agg = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
elif isinstance(stats, str) or callable(stats):
agg = _agg_rasterize(groups, stats, **kwargs)
else:
agg = groups.reduce(stats, keep_attrs=True, **kwargs)
raise ValueError(f"{stats} is not a valid aggregation.")

vec_cube = (
agg.reindex(group=range(len(geometry)))
.assign_coords(group=geometry)
Expand All @@ -68,7 +104,7 @@ def _zonal_stats_iterative(
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: str = "geometry",
all_touched: bool = False,
n_jobs: int = -1,
Expand Down Expand Up @@ -168,7 +204,7 @@ def _agg_geom(
trans,
x_coords: str = None,
y_coords: str = None,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
all_touched=False,
**kwargs,
):
Expand Down Expand Up @@ -216,14 +252,31 @@ def _agg_geom(
all_touched=all_touched,
)
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
if isinstance(stats, str):
result = getattr(masked, stats)(
dim=(y_coords, x_coords), keep_attrs=True, **kwargs
if pd.api.types.is_list_like(stats):
agg = {}
for stat in stats:
if isinstance(stat, str):
agg[stat] = _agg_iterate(masked, stat, x_coords, y_coords, **kwargs)
elif callable(stat):
agg[stat.__name__] = _agg_iterate(
masked, stat, x_coords, y_coords, **kwargs
)
elif isinstance(stat, tuple):
kws = stat[2] if len(stat) == 3 else {}
agg[stat[0]] = _agg_iterate(masked, stat[1], x_coords, y_coords, **kws)
else:
raise ValueError(f"{stat} is not a valid aggregation.")

result = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
elif isinstance(stats, str) or callable(stats):
result = _agg_iterate(masked, stats, x_coords, y_coords, **kwargs)
else:
result = masked.reduce(
stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs
)
raise ValueError(f"{stats} is not a valid aggregation.")

del mask
gc.collect()
Expand Down

0 comments on commit e86d3ec

Please sign in to comment.