From 697380594d8612d368e35e27f94baa974d8e4602 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 1 Jul 2024 17:17:09 -0600 Subject: [PATCH] Updates --- xvec/accessor.py | 30 +++++++++++++++++------------- xvec/tests/conftest.py | 16 ++++++++++++++++ xvec/tests/test_accessor.py | 3 ++- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/xvec/accessor.py b/xvec/accessor.py index e6e4e3a..3c1f1d8 100644 --- a/xvec/accessor.py +++ b/xvec/accessor.py @@ -1282,6 +1282,10 @@ def encode_cf(self) -> xr.Dataset: for crs, grid_mapping in grid_mappings.items(): grid_mapping_attrs = crs.to_cf() + # TODO: not all CRS can be represented by CF grid_mappings + # For now, we allow this. + # if "grid_mapping_name" not in grid_mapping_attrs: + # raise ValueError wkt_str = crs.to_wkt() grid_mapping_attrs["spatial_ref"] = wkt_str grid_mapping_attrs["crs_wkt"] = wkt_str @@ -1302,25 +1306,25 @@ def decode_cf(self) -> xr.Dataset: import cf_xarray as cfxr decoded = cfxr.geometry.decode_geometries(self._obj.copy()) - try: - # TODO: handle multiple CRS - grid_mapping = self._obj.cf["grid_mapping"] - crs = CRS.from_cf(grid_mapping.attrs) - except KeyError: - crs = None - + crs = { + name: CRS.from_cf(var.attrs) + for name, var in decoded._variables.items() + if "crs_wkt" in var.attrs or "grid_mapping_name" in var.attrs + } dims = decoded.xvec.geom_coords.dims for dim in dims: decoded = ( decoded.set_xindex(dim) if dim not in decoded._indexes else decoded ) - decoded = decoded.xvec.set_geom_indexes(dim, crs=crs) - if crs: + decoded = decoded.xvec.set_geom_indexes( + dim, crs=crs.get(decoded[dim].attrs.get("grid_mapping", None)) + ) + for name in crs: # remove spatial_ref so the coordinate system is only stored on the index - del decoded[grid_mapping.name] - for var in decoded._variables.values(): - if set(dims) & set(var.dims): - var.attrs.pop("grid_mapping", None) + del decoded[name] + for var in decoded._variables.values(): + if set(dims) & set(var.dims): + var.attrs.pop("grid_mapping", None) return decoded diff --git a/xvec/tests/conftest.py b/xvec/tests/conftest.py index dda08d5..0680f44 100644 --- a/xvec/tests/conftest.py +++ b/xvec/tests/conftest.py @@ -82,6 +82,21 @@ def multi_geom_dataset(geom_array, geom_array_z): ) +@pytest.fixture(scope="session") +def multi_geom_multi_crs_dataset(geom_array, geom_array_z): + return ( + xr.Dataset( + coords={ + "geom": geom_array, + "geom_z": geom_array_z, + } + ) + .drop_indexes(["geom", "geom_z"]) + .set_xindex("geom", GeometryIndex, crs=26915) + .set_xindex("geom_z", GeometryIndex, crs="EPSG:4362") + ) + + @pytest.fixture(scope="session") def multi_geom_no_index_dataset(geom_array, geom_array_z): return ( @@ -165,6 +180,7 @@ def traffic_dataset(geom_array): "multi_dataset", "multi_geom_dataset", "multi_geom_no_index_dataset", + "multi_geom_multi_crs_dataset", "traffic_dataset", ], scope="session", diff --git a/xvec/tests/test_accessor.py b/xvec/tests/test_accessor.py index d8638b4..db00725 100644 --- a/xvec/tests/test_accessor.py +++ b/xvec/tests/test_accessor.py @@ -684,7 +684,8 @@ def test_cf_roundtrip(all_datasets): if unique_crs := { idx.crs for idx in ds.xvec.geom_coords_indexed.xindexes.values() if idx.crs }: - assert len(unique_crs) == len(encoded.cf[["grid_mapping"]]) + nwkts = sum(1 for var in encoded._variables.values() if "crs_wkt" in var.attrs) + assert len(unique_crs) == nwkts roundtripped = encoded.xvec.decode_cf()