Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add encode_cf, decode_cf #69

Merged
merged 21 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@ dmypy.json

# sphinx
doc/source/generated
doc/source/geo-encoded*

# ruff
.ruff_cache
doc/source/cube.joblib.compressed
doc/source/cube.pickle

cache/
cache/
4 changes: 3 additions & 1 deletion doc/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ Methods
Dataset.xvec.to_geopandas
Dataset.xvec.extract_points
Dataset.xvec.zonal_stats
Dataset.xvec.encode_cf
Dataset.xvec.decode_cf


DataArray.xvec
Expand Down Expand Up @@ -91,4 +93,4 @@ Methods
DataArray.xvec.to_geodataframe
DataArray.xvec.to_geopandas
DataArray.xvec.extract_points
DataArray.xvec.zonal_stats
DataArray.xvec.zonal_stats
142 changes: 81 additions & 61 deletions doc/source/io.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ dependencies:
# required
- shapely=2
- xarray
- cf_xarray
# testing
- pytest
- pytest-cov
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"xarray >= 2022.12.0",
"pyproj >= 3.0.0",
"shapely >= 2.0b1",
"cf_xarray >= 0.9.2",
]

[project.urls]
Expand Down
122 changes: 121 additions & 1 deletion xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def geom_coords(self) -> Mapping[Hashable, xr.DataArray]:
).coords

@property
def geom_coords_indexed(self) -> Mapping[Hashable, xr.DataArray]:
def geom_coords_indexed(self) -> xr.Coordinates:
"""Returns a dictionary of xarray.DataArray objects corresponding to
coordinate variables using :class:`~xvec.GeometryIndex`.

Expand Down Expand Up @@ -1258,6 +1258,126 @@ def extract_points(
)
return result

def encode_cf(self) -> xr.Dataset:
"""
Encode all geometry variables and associated CRS with CF conventions.

Use this method prior to writing an Xarray dataset to any array format
(e.g. netCDF or Zarr).

The following invariant is satisfied:
``assert ds.xvec.encode_cf().xvec.decode_cf().identical(ds) is True``

CRS information on the ``GeometryIndex`` is encoded using CF's ``grid_mapping`` convention.

This function uses ``cf_xarray.geometry.encode_geometries`` under the hood and will only
work on Datasets.

Returns
-------
Dataset
"""
import cf_xarray as cfxr

if not isinstance(self._obj, xr.Dataset):
raise ValueError(
"CF encoding is only valid on Datasets. Convert to a dataset using `.to_dataset()` first."
)

ds = self._obj.copy()
coords = self.geom_coords_indexed

# TODO: this could use geoxarray, but is quite simple in any case
# Adapted from rioxarray
# 1. First find all unique CRS objects
# preserve ordering for roundtripping
unique_crs = []
for _, xi in sorted(coords.xindexes.items()):
if xi.crs not in unique_crs:
unique_crs.append(xi.crs)
if len(unique_crs) == 1:
grid_mappings = {unique_crs.pop(): "spatial_ref"}
else:
grid_mappings = {
crs_: f"spatial_ref_{i}" for i, crs_ in enumerate(unique_crs)
}

# 2. Convert CRS to grid_mapping variables and assign them
for crs, grid_mapping in grid_mappings.items():
grid_mapping_attrs = crs.to_cf()
# TODO: not all CRS can be represented by CF grid_mappings
# For now, we allow this.
# if "grid_mapping_name" not in grid_mapping_attrs:
# raise ValueError
wkt_str = crs.to_wkt()
grid_mapping_attrs["spatial_ref"] = wkt_str
grid_mapping_attrs["crs_wkt"] = wkt_str
ds.coords[grid_mapping] = xr.Variable(
dims=(), data=0, attrs=grid_mapping_attrs
)

# 3. Associate other variables with appropriate grid_mapping variable
# We asumme that this relation follows from dimension names being shared between
# the GeometryIndex and the variable being checked.
for name, coord in coords.items():
dims = set(coord.dims)
index = coords.xindexes[name]
varnames = (k for k, v in ds._variables.items() if dims & set(v.dims))
for name in varnames:
if TYPE_CHECKING:
assert isinstance(index, GeometryIndex)
ds._variables[name].attrs["grid_mapping"] = grid_mappings[index.crs]

encoded = cfxr.geometry.encode_geometries(ds)
return encoded

def decode_cf(self) -> xr.Dataset:
"""
Decode geometries stored as CF-compliant arrays to shapely geometries.

The following invariant is satisfied:
``assert ds.xvec.encode_cf().xvec.decode_cf().identical(ds) is True``


A ``GeometryIndex`` is created automatically and CRS information, if available
following CF's ``grid_mapping`` convention, will be associated with the ``GeometryIndex``.

This function uses ``cf_xarray.geometry.decode_geometries`` under the hood, and will only
work on Datasets.

Returns
-------
Dataset
"""
import cf_xarray as cfxr

if not isinstance(self._obj, xr.Dataset):
raise ValueError(
"CF decoding is only supported on Datasets. Convert to a Dataset using `.to_dataset()` first."
)

decoded = cfxr.geometry.decode_geometries(self._obj.copy())
crs = {
name: CRS.from_user_input(var.attrs["crs_wkt"])
for name, var in decoded._variables.items()
if "crs_wkt" in var.attrs or "grid_mapping_name" in var.attrs
}
dims = decoded.xvec.geom_coords.dims
for dim in dims:
decoded = (
decoded.set_xindex(dim) if dim not in decoded._indexes else decoded
)
decoded = decoded.xvec.set_geom_indexes(
dim, crs=crs.get(decoded[dim].attrs.get("grid_mapping", None))
)
Comment on lines +1370 to +1372
Copy link
Contributor Author

@dcherian dcherian Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key buggy line. it always sets the index, we do not record which geometry dims were indexed at encode-time. What should we do here?

As an aside it'd be nice for set_geom_indexes to understand the grid_mapping convention. WDYT?

One approach: decode_cf does NOT set the new index, but the user does so manually. Instead set_geom_indexes learns how to interpret the grid_mapping convention so CRS is set properly by default.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it always sets the index, we do not record which geometry dims were indexed at encode-time. What should we do here?

Is that an issue if we just index all geom dims encoded in the file?

As an aside it'd be nice for set_geom_indexes to understand the grid_mapping convention. WDYT?

Not against but I don't really know what would it mean implementation-wise. Maybe just a simple call to pyproj.CRS.from_cf?

set_geom_indexes learns how to interpret the grid_mapping convention so CRS is set properly by default.

That would be preferable. Not a fan of asking users to set indexes after reading what already was indexed before writing.

for name in crs:
# remove spatial_ref so the coordinate system is only stored on the index
del decoded[name]
for var in decoded._variables.values():
if set(dims) & set(var.dims):
var.attrs.pop("grid_mapping", None)
return decoded


def _resolve_input(
positional: Mapping[Any, Any] | None,
Expand Down
43 changes: 41 additions & 2 deletions xvec/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def multi_dataset(geom_array, geom_array_z):

@pytest.fixture(scope="session")
def multi_geom_dataset(geom_array, geom_array_z):
return (
ds = (
xr.Dataset(
coords={
"geom": geom_array,
Expand All @@ -80,11 +80,32 @@ def multi_geom_dataset(geom_array, geom_array_z):
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs=26915)
)
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
Comment on lines +83 to +84
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason you can't set these in GeometryIndex.create_variables?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apart from "no one thought about that until now", I am not aware of any.

return ds


@pytest.fixture(scope="session")
def multi_geom_multi_crs_dataset(geom_array, geom_array_z):
ds = (
xr.Dataset(
coords={
"geom": geom_array,
"geom_z": geom_array_z,
}
)
.drop_indexes(["geom", "geom_z"])
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs="EPSG:4362")
)
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
return ds


@pytest.fixture(scope="session")
def multi_geom_no_index_dataset(geom_array, geom_array_z):
return (
ds = (
xr.Dataset(
coords={
"geom": geom_array,
Expand All @@ -96,6 +117,9 @@ def multi_geom_no_index_dataset(geom_array, geom_array_z):
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs=26915)
)
ds["geom"].attrs["crs"] = ds.xindexes["geom"].crs
ds["geom_z"].attrs["crs"] = ds.xindexes["geom_z"].crs
return ds


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -157,3 +181,18 @@ def traffic_dataset(geom_array):
"day": pd.date_range("2023-01-01", periods=10),
},
).xvec.set_geom_indexes(["origin", "destination"], crs=26915)


@pytest.fixture(
params=[
"first_geom_dataset",
"multi_dataset",
"multi_geom_dataset",
"multi_geom_no_index_dataset",
"multi_geom_multi_crs_dataset",
"traffic_dataset",
],
scope="session",
)
def all_datasets(request):
return request.getfixturevalue(request.param)
31 changes: 31 additions & 0 deletions xvec/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,3 +674,34 @@ def test_extract_points_array():
geometry=4326
),
)


def test_cf_roundtrip(all_datasets):
ds = all_datasets
copy = ds.copy(deep=True)
encoded = ds.xvec.encode_cf()

if unique_crs := {
idx.crs for idx in ds.xvec.geom_coords_indexed.xindexes.values() if idx.crs
}:
nwkts = sum(1 for var in encoded._variables.values() if "crs_wkt" in var.attrs)
assert len(unique_crs) == nwkts
roundtripped = encoded.xvec.decode_cf()

xr.testing.assert_identical(ds, roundtripped)
assert_indexes_equals(ds, roundtripped)
# make sure we didn't modify the original dataset.
xr.testing.assert_identical(ds, copy)


def assert_indexes_equals(left, right):
# Till https://github.com/pydata/xarray/issues/5812 is resolved
# Also, we don't record whether an unindexed coordinate was serialized
# So just asssert that the left ("expected") dataset has fewer indexes
# than the right.
# This isn't great...
assert sorted(left.xindexes.keys()) <= sorted(right.xindexes.keys())
for k in left.xindexes:
if not isinstance(left.xindexes[k], GeometryIndex):
continue
assert left.xindexes[k].equals(right.xindexes[k])
Loading