diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index db32de3c9cd..80ea29746f8 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -224,6 +224,10 @@ New Features
- Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed.
By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`).
+- Decode/Encode netCDF4 enums and store the enum definition in dataarrays' dtype metadata.
+ If multiple variables share the same enum in netCDF4, each dataarray will have its own
+ enum definition in their respective dtype metadata.
+ By `Abel Aoun `_. (:issue:`8144`, :pull:`8147`)
- Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`).
By `Ben Mares `_.
- Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the
diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py
index cf753828242..d3845568709 100644
--- a/xarray/backends/netCDF4_.py
+++ b/xarray/backends/netCDF4_.py
@@ -49,7 +49,6 @@
# string used by netCDF4.
_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"}
-
NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK])
@@ -141,7 +140,9 @@ def _check_encoding_dtype_is_vlen_string(dtype):
)
-def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False):
+def _get_datatype(
+ var, nc_format="NETCDF4", raise_on_invalid_encoding=False
+) -> np.dtype:
if nc_format == "NETCDF4":
return _nc4_dtype(var)
if "dtype" in var.encoding:
@@ -234,13 +235,13 @@ def _force_native_endianness(var):
def _extract_nc4_variable_encoding(
- variable,
+ variable: Variable,
raise_on_invalid=False,
lsd_okay=True,
h5py_okay=False,
backend="netCDF4",
unlimited_dims=None,
-):
+) -> dict[str, Any]:
if unlimited_dims is None:
unlimited_dims = ()
@@ -308,7 +309,7 @@ def _extract_nc4_variable_encoding(
return encoding
-def _is_list_of_strings(value):
+def _is_list_of_strings(value) -> bool:
arr = np.asarray(value)
return arr.dtype.kind in ["U", "S"] and arr.size > 1
@@ -414,13 +415,25 @@ def _acquire(self, needs_lock=True):
def ds(self):
return self._acquire()
- def open_store_variable(self, name, var):
+ def open_store_variable(self, name: str, var):
+ import netCDF4
+
dimensions = var.dimensions
- data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
attributes = {k: var.getncattr(k) for k in var.ncattrs()}
+ data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self))
+ encoding: dict[str, Any] = {}
+ if isinstance(var.datatype, netCDF4.EnumType):
+ encoding["dtype"] = np.dtype(
+ data.dtype,
+ metadata={
+ "enum": var.datatype.enum_dict,
+ "enum_name": var.datatype.name,
+ },
+ )
+ else:
+ encoding["dtype"] = var.dtype
_ensure_fill_value_valid(data, attributes)
# netCDF4 specific encoding; save _FillValue for later
- encoding = {}
filters = var.filters()
if filters is not None:
encoding.update(filters)
@@ -440,7 +453,6 @@ def open_store_variable(self, name, var):
# save source so __repr__ can detect if it's local or not
encoding["source"] = self._filename
encoding["original_shape"] = var.shape
- encoding["dtype"] = var.dtype
return Variable(dimensions, data, attributes, encoding)
@@ -485,21 +497,24 @@ def encode_variable(self, variable):
return variable
def prepare_variable(
- self, name, variable, check_encoding=False, unlimited_dims=None
+ self, name, variable: Variable, check_encoding=False, unlimited_dims=None
):
_ensure_no_forward_slash_in_name(name)
-
+ attrs = variable.attrs.copy()
+ fill_value = attrs.pop("_FillValue", None)
datatype = _get_datatype(
variable, self.format, raise_on_invalid_encoding=check_encoding
)
- attrs = variable.attrs.copy()
-
- fill_value = attrs.pop("_FillValue", None)
-
+ # check enum metadata and use netCDF4.EnumType
+ if (
+ (meta := np.dtype(datatype).metadata)
+ and (e_name := meta.get("enum_name"))
+ and (e_dict := meta.get("enum"))
+ ):
+ datatype = self._build_and_get_enum(name, datatype, e_name, e_dict)
encoding = _extract_nc4_variable_encoding(
variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims
)
-
if name in self.ds.variables:
nc4_var = self.ds.variables[name]
else:
@@ -527,6 +542,33 @@ def prepare_variable(
return target, variable.data
+ def _build_and_get_enum(
+ self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int]
+ ) -> Any:
+ """
+ Add or get the netCDF4 Enum based on the dtype in encoding.
+ The return type should be ``netCDF4.EnumType``,
+ but we avoid importing netCDF4 globally for performance reasons.
+ """
+ if enum_name not in self.ds.enumtypes:
+ return self.ds.createEnumType(
+ dtype,
+ enum_name,
+ enum_dict,
+ )
+ datatype = self.ds.enumtypes[enum_name]
+ if datatype.enum_dict != enum_dict:
+ error_msg = (
+ f"Cannot save variable `{var_name}` because an enum"
+ f" `{enum_name}` already exists in the Dataset but have"
+ " a different definition. To fix this error, make sure"
+ " each variable have a uniquely named enum in their"
+ " `encoding['dtype'].metadata` or, if they should share"
+ " the same enum type, make sure the enums are identical."
+ )
+ raise ValueError(error_msg)
+ return datatype
+
def sync(self):
self.ds.sync()
diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py
index 54383394b06..c3d57ad1903 100644
--- a/xarray/coding/variables.py
+++ b/xarray/coding/variables.py
@@ -566,7 +566,7 @@ def decode(self):
class ObjectVLenStringCoder(VariableCoder):
def encode(self):
- return NotImplementedError
+ raise NotImplementedError
def decode(self, variable: Variable, name: T_Name = None) -> Variable:
if variable.dtype == object and variable.encoding.get("dtype", False) == str:
@@ -574,3 +574,22 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
return variable
else:
return variable
+
+
+class NativeEnumCoder(VariableCoder):
+ """Encode Enum into variable dtype metadata."""
+
+ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
+ if (
+ "dtype" in variable.encoding
+ and np.dtype(variable.encoding["dtype"]).metadata
+ and "enum" in variable.encoding["dtype"].metadata
+ ):
+ dims, data, attrs, encoding = unpack_for_encoding(variable)
+ data = data.astype(dtype=variable.encoding.pop("dtype"))
+ return Variable(dims, data, attrs, encoding, fastpath=True)
+ else:
+ return variable
+
+ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
+ raise NotImplementedError()
diff --git a/xarray/conventions.py b/xarray/conventions.py
index 94e285d8e1d..1d8e81e1bf2 100644
--- a/xarray/conventions.py
+++ b/xarray/conventions.py
@@ -48,10 +48,6 @@
T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore]
-def _var_as_tuple(var: Variable) -> T_VarTuple:
- return var.dims, var.data, var.attrs.copy(), var.encoding.copy()
-
-
def _infer_dtype(array, name=None):
"""Given an object array with no missing values, infer its dtype from all elements."""
if array.dtype.kind != "O":
@@ -111,7 +107,7 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike):
def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable:
# TODO: move this from conventions to backends? (it's not CF related)
if var.dtype.kind == "O":
- dims, data, attrs, encoding = _var_as_tuple(var)
+ dims, data, attrs, encoding = variables.unpack_for_encoding(var)
# leave vlen dtypes unchanged
if strings.check_vlen_dtype(data.dtype) is not None:
@@ -162,7 +158,7 @@ def encode_cf_variable(
var: Variable, needs_copy: bool = True, name: T_Name = None
) -> Variable:
"""
- Converts an Variable into an Variable which follows some
+ Converts a Variable into a Variable which follows some
of the CF conventions:
- Nans are masked using _FillValue (or the deprecated missing_value)
@@ -188,6 +184,7 @@ def encode_cf_variable(
variables.CFScaleOffsetCoder(),
variables.CFMaskCoder(),
variables.UnsignedIntegerCoder(),
+ variables.NativeEnumCoder(),
variables.NonStringCoder(),
variables.DefaultFillvalueCoder(),
variables.BooleanCoder(),
@@ -447,7 +444,7 @@ def stackable(dim: Hashable) -> bool:
decode_timedelta=decode_timedelta,
)
except Exception as e:
- raise type(e)(f"Failed to decode variable {k!r}: {e}")
+ raise type(e)(f"Failed to decode variable {k!r}: {e}") from e
if decode_coords in [True, "coordinates", "all"]:
var_attrs = new_vars[k].attrs
if "coordinates" in var_attrs:
@@ -633,7 +630,11 @@ def cf_decoder(
decode_cf_variable
"""
variables, attributes, _ = decode_cf_variables(
- variables, attributes, concat_characters, mask_and_scale, decode_times
+ variables,
+ attributes,
+ concat_characters,
+ mask_and_scale,
+ decode_times,
)
return variables, attributes
diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py
index d2ea0f8a1a4..e6e65c73a53 100644
--- a/xarray/core/dataarray.py
+++ b/xarray/core/dataarray.py
@@ -4062,6 +4062,9 @@ def to_netcdf(
name is the same as a coordinate name, then it is given the name
``"__xarray_dataarray_variable__"``.
+ [netCDF4 backend only] netCDF4 enums are decoded into the
+ dataarray dtype metadata.
+
See Also
--------
Dataset.to_netcdf
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 104b6d0867d..d01cfd7ff55 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -1704,6 +1704,126 @@ def test_raise_on_forward_slashes_in_names(self) -> None:
with self.roundtrip(ds):
pass
+ @requires_netCDF4
+ def test_encoding_enum__no_fill_value(self):
+ with create_tmp_file() as tmp_file:
+ cloud_type_dict = {"clear": 0, "cloudy": 1}
+ with nc4.Dataset(tmp_file, mode="w") as nc:
+ nc.createDimension("time", size=2)
+ cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
+ v = nc.createVariable(
+ "clouds",
+ cloud_type,
+ "time",
+ fill_value=None,
+ )
+ v[:] = 1
+ with open_dataset(tmp_file) as original:
+ save_kwargs = {}
+ if self.engine == "h5netcdf":
+ save_kwargs["invalid_netcdf"] = True
+ with self.roundtrip(original, save_kwargs=save_kwargs) as actual:
+ assert_equal(original, actual)
+ assert (
+ actual.clouds.encoding["dtype"].metadata["enum"]
+ == cloud_type_dict
+ )
+ if self.engine != "h5netcdf":
+ # not implemented in h5netcdf yet
+ assert (
+ actual.clouds.encoding["dtype"].metadata["enum_name"]
+ == "cloud_type"
+ )
+
+ @requires_netCDF4
+ def test_encoding_enum__multiple_variable_with_enum(self):
+ with create_tmp_file() as tmp_file:
+ cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
+ with nc4.Dataset(tmp_file, mode="w") as nc:
+ nc.createDimension("time", size=2)
+ cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
+ nc.createVariable(
+ "clouds",
+ cloud_type,
+ "time",
+ fill_value=255,
+ )
+ nc.createVariable(
+ "tifa",
+ cloud_type,
+ "time",
+ fill_value=255,
+ )
+ with open_dataset(tmp_file) as original:
+ save_kwargs = {}
+ if self.engine == "h5netcdf":
+ save_kwargs["invalid_netcdf"] = True
+ with self.roundtrip(original, save_kwargs=save_kwargs) as actual:
+ assert_equal(original, actual)
+ assert (
+ actual.clouds.encoding["dtype"] == actual.tifa.encoding["dtype"]
+ )
+ assert (
+ actual.clouds.encoding["dtype"].metadata
+ == actual.tifa.encoding["dtype"].metadata
+ )
+ assert (
+ actual.clouds.encoding["dtype"].metadata["enum"]
+ == cloud_type_dict
+ )
+ if self.engine != "h5netcdf":
+ # not implemented in h5netcdf yet
+ assert (
+ actual.clouds.encoding["dtype"].metadata["enum_name"]
+ == "cloud_type"
+ )
+
+ @requires_netCDF4
+ def test_encoding_enum__error_multiple_variable_with_changing_enum(self):
+ """
+ Given two variables, if they share the same enum type,
+ their enum definitions should be identical.
+ """
+ with create_tmp_file() as tmp_file:
+ cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255}
+ with nc4.Dataset(tmp_file, mode="w") as nc:
+ nc.createDimension("time", size=2)
+ cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict)
+ nc.createVariable(
+ "clouds",
+ cloud_type,
+ "time",
+ fill_value=255,
+ )
+ nc.createVariable(
+ "tifa",
+ cloud_type,
+ "time",
+ fill_value=255,
+ )
+ with open_dataset(tmp_file) as original:
+ assert (
+ original.clouds.encoding["dtype"].metadata
+ == original.tifa.encoding["dtype"].metadata
+ )
+ modified_enum = original.clouds.encoding["dtype"].metadata["enum"]
+ modified_enum.update({"neblig": 2})
+ original.clouds.encoding["dtype"] = np.dtype(
+ "u1",
+ metadata={"enum": modified_enum, "enum_name": "cloud_type"},
+ )
+ if self.engine != "h5netcdf":
+ # not implemented yet in h5netcdf
+ with pytest.raises(
+ ValueError,
+ match=(
+ "Cannot save variable .*"
+ " because an enum `cloud_type` already exists in the Dataset .*"
+ ),
+ ):
+ with self.roundtrip(original):
+ pass
+
@requires_netCDF4
class TestNetCDF4Data(NetCDF4Base):