Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: support exactextract as a method in zonal_stats #68

Merged
merged 18 commits into from
Jul 2, 2024
3 changes: 3 additions & 0 deletions ci/310.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ dependencies:
- geopandas-base
- geodatasets
- pyogrio
- pip
- pip:
- exactextract==0.2.0.dev0

3 changes: 3 additions & 0 deletions ci/311.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ dependencies:
- geopandas-base
- geodatasets
- pyogrio
- pip
- pip:
- exactextract==0.2.0.dev0

3 changes: 3 additions & 0 deletions ci/312.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,7 @@ dependencies:
- geodatasets
- pyogrio
- mypy
- pip
- pip:
- exactextract==0.2.0.dev0

3 changes: 3 additions & 0 deletions ci/39.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,6 @@ dependencies:
- geopandas-base
- geodatasets
- pyogrio
- pip
- pip:
- exactextract==0.2.0.dev0
1 change: 1 addition & 0 deletions ci/dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ dependencies:
- git+https://github.com/shapely/shapely.git@main
- git+https://github.com/pydata/xarray.git@main
- git+https://github.com/pyproj4/pyproj.git
- exactextract==0.2.0.dev0
16 changes: 15 additions & 1 deletion xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
from pyproj import CRS, Transformer

from .index import GeometryIndex
from .zonal import _zonal_stats_iterative, _zonal_stats_rasterize
from .zonal import (
_zonal_stats_exactextract,
_zonal_stats_iterative,
_zonal_stats_rasterize,
)

if TYPE_CHECKING:
from geopandas import GeoDataFrame
Expand Down Expand Up @@ -1088,6 +1092,16 @@ def zonal_stats(
n_jobs=n_jobs,
**kwargs,
)
elif method == "exactextract":
result = _zonal_stats_exactextract(
self,
geometry=geometry,
x_coords=x_coords,
y_coords=y_coords,
stats=stats,
name=name,
**kwargs,
)
else:
raise ValueError(
f"method '{method}' is not supported. Allowed options are 'rasterize' "
Expand Down
123 changes: 77 additions & 46 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xvec # noqa: F401


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
def test_structure(method):
da = xr.DataArray(
np.ones((10, 10, 5)),
Expand All @@ -24,14 +24,22 @@ def test_structure(method):
polygon2 = shapely.geometry.Polygon([(6, 22), (9, 22), (9, 29), (6, 26)])
polygons = gpd.GeoSeries([polygon1, polygon2], crs="EPSG:4326")

expected = xr.DataArray(
np.array([[12.0] * 5, [18.0] * 5]),
coords={
"geometry": polygons,
"time": pd.date_range("2023-01-01", periods=5),
},
).xvec.set_geom_indexes("geometry", crs="EPSG:4326")

if method == "exactextract":
expected = xr.DataArray(
np.array([[12.0] * 5, [16.5] * 5]),
coords={
"geometry": polygons,
"time": pd.date_range("2023-01-01", periods=5),
},
).xvec.set_geom_indexes("geometry", crs="EPSG:4326")
else:
expected = xr.DataArray(
np.array([[12.0] * 5, [18.0] * 5]),
coords={
"geometry": polygons,
"time": pd.date_range("2023-01-01", periods=5),
},
).xvec.set_geom_indexes("geometry", crs="EPSG:4326")
actual = da.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
xr.testing.assert_identical(actual, expected)

Expand All @@ -43,35 +51,36 @@ def test_structure(method):
)

# dataset
ds = da.to_dataset(name="test")

expected_ds = expected.to_dataset(name="test").set_coords("geometry")
actual_ds = ds.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
xr.testing.assert_identical(actual_ds, expected_ds)

actual_ix_ds = ds.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method, index=True
)
xr.testing.assert_identical(
actual_ix_ds, expected_ds.assign_coords({"index": ("geometry", polygons.index)})
)
if method == "rasterize" or method == "iterate":
ds = da.to_dataset(name="test")
expected_ds = expected.to_dataset(name="test").set_coords("geometry")
actual_ds = ds.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method)
xr.testing.assert_identical(actual_ds, expected_ds)

actual_ix_ds = ds.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method, index=True
)
xr.testing.assert_identical(
actual_ix_ds,
expected_ds.assign_coords({"index": ("geometry", polygons.index)}),
)

# named index
polygons.index.name = "my_index"
actual_ix_named = da.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method
)
xr.testing.assert_identical(
actual_ix_named,
expected.assign_coords({"my_index": ("geometry", polygons.index)}),
)
actual_ix_names_ds = ds.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method
)
xr.testing.assert_identical(
actual_ix_names_ds,
expected_ds.assign_coords({"my_index": ("geometry", polygons.index)}),
)
# named index
polygons.index.name = "my_index"
actual_ix_named = da.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method
)
xr.testing.assert_identical(
actual_ix_named,
expected.assign_coords({"my_index": ("geometry", polygons.index)}),
)
actual_ix_names_ds = ds.xvec.zonal_stats(
polygons, "x", "y", stats="sum", method=method
)
xr.testing.assert_identical(
actual_ix_names_ds,
expected_ds.assign_coords({"my_index": ("geometry", polygons.index)}),
)


def test_match():
Expand Down Expand Up @@ -105,7 +114,7 @@ def test_dataset(method):
)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
def test_dataarray(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
Expand All @@ -115,10 +124,13 @@ def test_dataarray(method):

assert result.shape == (127, 2, 3)
assert result.dims == ("geometry", "month", "level")
assert result.mean() == pytest.approx(61367.76185577)
if method == "exactextract":
assert result.mean() == pytest.approx(61625.53438858)
else:
assert result.mean() == pytest.approx(61367.76185577)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
@pytest.mark.parametrize("method", ["rasterize", "iterate", "exactextract"])
def test_stat(method):
ds = xr.tutorial.open_dataset("eraint_uvz")
world = gpd.read_file(geodatasets.get_path("naturalearth land"))
Expand All @@ -129,13 +141,32 @@ def test_stat(method):
median_ = ds.z.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats="median"
)
quantile_ = ds.z.xvec.zonal_stats(
world.geometry, "longitude", "latitude", method=method, stats="quantile", q=0.2
)
if method == "exactextract":
quantile_ = ds.z.xvec.zonal_stats(
world.geometry,
"longitude",
"latitude",
method=method,
stats="quantile(q=0.33)",
masawdah marked this conversation as resolved.
Show resolved Hide resolved
)
else:
quantile_ = ds.z.xvec.zonal_stats(
world.geometry,
"longitude",
"latitude",
method=method,
stats="quantile",
q=0.2,
)

assert mean_.mean() == pytest.approx(61367.76185577)
assert median_.mean() == pytest.approx(61370.18563539)
assert quantile_.mean() == pytest.approx(61279.93619836)
if method == "exactextract":
assert mean_.mean() == pytest.approx(61625.53438858)
assert median_.mean() == pytest.approx(61628.67168691)
assert quantile_.mean() == pytest.approx(61576.0883029)
else:
assert mean_.mean() == pytest.approx(61367.76185577)
assert median_.mean() == pytest.approx(61370.18563539)
assert quantile_.mean() == pytest.approx(61279.93619836)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
Expand Down
87 changes: 87 additions & 0 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,90 @@ def _agg_geom(
gc.collect()

return result


def _zonal_stats_exactextract(
acc,
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: str = "geometry",
**kwargs,
) -> xr.DataArray:
try:
import exactextract
except ImportError as err:
raise ImportError(
"The exactextract package is required for `zonal_stats()`. "
"You can install it using or 'pip install exactextract'."
) from err

try:
import geopandas as gpd
except ImportError as err:
raise ImportError(
"The geopandas package is required for `xvec.to_geodataframe()`. "
"You can install it using 'conda install -c conda-forge geopandas' or "
"'pip install geopandas'."
) from err

if hasattr(geometry, "crs"):
crs = geometry.crs # type: ignore
else:
crs = None

# the input should be xarray.DataArray
if not isinstance(acc._obj, xr.core.dataarray.DataArray):
acc._obj = acc._obj.to_dataarray()

# Get all the dimensions execpt x_coords, y_coords, they will be used to stack the dataarray later
arr_dims = tuple(dim for dim in acc._obj.dims if dim not in [x_coords, y_coords])

# Get the original information to use for unstacking the resulte later
coords_info = {name: geometry}
original_shape = [len(geometry)]
for dim in arr_dims:
original_shape.append(acc._obj[dim].size)
coords_info[dim] = acc._obj[dim].values

# Stack the other dimensions into one dimension called "location"
data = acc._obj.stack(location=arr_dims)
locs = data.location.size

# Check the order of dimensions
data = data.transpose("location", y_coords, x_coords)

# Aggregation result
gdf = gpd.GeoDataFrame(geometry=geometry, crs=crs)
results = exactextract.exact_extract(rast=data, vec=gdf, ops=stats, output="pandas")

# Unstack the results
if pd.api.types.is_list_like(stats):
agg = {}
i = 0
for stat in stats: # type: ignore
df = results.iloc[:, i : i + locs]
# Unstack the result
arr = df.values.reshape(original_shape)
result = xr.DataArray(
arr, coords=coords_info, dims=coords_info.keys()
).xvec.set_geom_indexes(name, crs=crs)
agg[stat] = result
i += locs
vec_cube = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
elif isinstance(stats, str):
# Unstack the result
arr = results.values.reshape(original_shape)
vec_cube = xr.DataArray(
arr, coords=coords_info, dims=coords_info.keys()
).xvec.set_geom_indexes(name, crs=crs)
else:
raise ValueError(f"{stats} is not a valid aggregation for exactextract method.")

return vec_cube
Loading