Skip to content

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jul 1, 2024
1 parent 678977a commit 6973805
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
30 changes: 17 additions & 13 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,10 @@ def encode_cf(self) -> xr.Dataset:

for crs, grid_mapping in grid_mappings.items():
grid_mapping_attrs = crs.to_cf()
# TODO: not all CRS can be represented by CF grid_mappings
# For now, we allow this.
# if "grid_mapping_name" not in grid_mapping_attrs:
# raise ValueError
wkt_str = crs.to_wkt()
grid_mapping_attrs["spatial_ref"] = wkt_str
grid_mapping_attrs["crs_wkt"] = wkt_str
Expand All @@ -1302,25 +1306,25 @@ def decode_cf(self) -> xr.Dataset:
import cf_xarray as cfxr

decoded = cfxr.geometry.decode_geometries(self._obj.copy())
try:
# TODO: handle multiple CRS
grid_mapping = self._obj.cf["grid_mapping"]
crs = CRS.from_cf(grid_mapping.attrs)
except KeyError:
crs = None

crs = {
name: CRS.from_cf(var.attrs)
for name, var in decoded._variables.items()
if "crs_wkt" in var.attrs or "grid_mapping_name" in var.attrs
}
dims = decoded.xvec.geom_coords.dims
for dim in dims:
decoded = (
decoded.set_xindex(dim) if dim not in decoded._indexes else decoded
)
decoded = decoded.xvec.set_geom_indexes(dim, crs=crs)
if crs:
decoded = decoded.xvec.set_geom_indexes(
dim, crs=crs.get(decoded[dim].attrs.get("grid_mapping", None))
)
for name in crs:
# remove spatial_ref so the coordinate system is only stored on the index
del decoded[grid_mapping.name]
for var in decoded._variables.values():
if set(dims) & set(var.dims):
var.attrs.pop("grid_mapping", None)
del decoded[name]
for var in decoded._variables.values():
if set(dims) & set(var.dims):
var.attrs.pop("grid_mapping", None)
return decoded


Expand Down
16 changes: 16 additions & 0 deletions xvec/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,21 @@ def multi_geom_dataset(geom_array, geom_array_z):
)


@pytest.fixture(scope="session")
def multi_geom_multi_crs_dataset(geom_array, geom_array_z):
return (
xr.Dataset(
coords={
"geom": geom_array,
"geom_z": geom_array_z,
}
)
.drop_indexes(["geom", "geom_z"])
.set_xindex("geom", GeometryIndex, crs=26915)
.set_xindex("geom_z", GeometryIndex, crs="EPSG:4362")
)


@pytest.fixture(scope="session")
def multi_geom_no_index_dataset(geom_array, geom_array_z):
return (
Expand Down Expand Up @@ -165,6 +180,7 @@ def traffic_dataset(geom_array):
"multi_dataset",
"multi_geom_dataset",
"multi_geom_no_index_dataset",
"multi_geom_multi_crs_dataset",
"traffic_dataset",
],
scope="session",
Expand Down
3 changes: 2 additions & 1 deletion xvec/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,8 @@ def test_cf_roundtrip(all_datasets):
if unique_crs := {
idx.crs for idx in ds.xvec.geom_coords_indexed.xindexes.values() if idx.crs
}:
assert len(unique_crs) == len(encoded.cf[["grid_mapping"]])
nwkts = sum(1 for var in encoded._variables.values() if "crs_wkt" in var.attrs)
assert len(unique_crs) == nwkts

roundtripped = encoded.xvec.decode_cf()

Expand Down

0 comments on commit 6973805

Please sign in to comment.