diff --git a/doc/whats-new.rst b/doc/whats-new.rst index db32de3c9cd..80ea29746f8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -224,6 +224,10 @@ New Features - Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed. By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`). +- Decode/Encode netCDF4 enums and store the enum definition in dataarrays' dtype metadata. + If multiple variables share the same enum in netCDF4, each dataarray will have its own + enum definition in their respective dtype metadata. + By `Abel Aoun _`(:issue:`8144`, :pull:`8147`) - Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`). By `Ben Mares `_. - Allow passing ``region="auto"`` in :py:meth:`Dataset.to_zarr` to automatically infer the diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index cf753828242..d3845568709 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -49,7 +49,6 @@ # string used by netCDF4. _endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"} - NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) @@ -141,7 +140,9 @@ def _check_encoding_dtype_is_vlen_string(dtype): ) -def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False): +def _get_datatype( + var, nc_format="NETCDF4", raise_on_invalid_encoding=False +) -> np.dtype: if nc_format == "NETCDF4": return _nc4_dtype(var) if "dtype" in var.encoding: @@ -234,13 +235,13 @@ def _force_native_endianness(var): def _extract_nc4_variable_encoding( - variable, + variable: Variable, raise_on_invalid=False, lsd_okay=True, h5py_okay=False, backend="netCDF4", unlimited_dims=None, -): +) -> dict[str, Any]: if unlimited_dims is None: unlimited_dims = () @@ -308,7 +309,7 @@ def _extract_nc4_variable_encoding( return encoding -def _is_list_of_strings(value): +def _is_list_of_strings(value) -> bool: arr = np.asarray(value) return arr.dtype.kind in ["U", "S"] and arr.size > 1 @@ -414,13 +415,25 
@@ def _acquire(self, needs_lock=True): def ds(self): return self._acquire() - def open_store_variable(self, name, var): + def open_store_variable(self, name: str, var): + import netCDF4 + dimensions = var.dimensions - data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) attributes = {k: var.getncattr(k) for k in var.ncattrs()} + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) + encoding: dict[str, Any] = {} + if isinstance(var.datatype, netCDF4.EnumType): + encoding["dtype"] = np.dtype( + data.dtype, + metadata={ + "enum": var.datatype.enum_dict, + "enum_name": var.datatype.name, + }, + ) + else: + encoding["dtype"] = var.dtype _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later - encoding = {} filters = var.filters() if filters is not None: encoding.update(filters) @@ -440,7 +453,6 @@ def open_store_variable(self, name, var): # save source so __repr__ can detect if it's local or not encoding["source"] = self._filename encoding["original_shape"] = var.shape - encoding["dtype"] = var.dtype return Variable(dimensions, data, attributes, encoding) @@ -485,21 +497,24 @@ def encode_variable(self, variable): return variable def prepare_variable( - self, name, variable, check_encoding=False, unlimited_dims=None + self, name, variable: Variable, check_encoding=False, unlimited_dims=None ): _ensure_no_forward_slash_in_name(name) - + attrs = variable.attrs.copy() + fill_value = attrs.pop("_FillValue", None) datatype = _get_datatype( variable, self.format, raise_on_invalid_encoding=check_encoding ) - attrs = variable.attrs.copy() - - fill_value = attrs.pop("_FillValue", None) - + # check enum metadata and use netCDF4.EnumType + if ( + (meta := np.dtype(datatype).metadata) + and (e_name := meta.get("enum_name")) + and (e_dict := meta.get("enum")) + ): + datatype = self._build_and_get_enum(name, datatype, e_name, e_dict) encoding = _extract_nc4_variable_encoding( variable, 
raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) - if name in self.ds.variables: nc4_var = self.ds.variables[name] else: @@ -527,6 +542,33 @@ def prepare_variable( return target, variable.data + def _build_and_get_enum( + self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int] + ) -> Any: + """ + Add or get the netCDF4 Enum based on the dtype in encoding. + The return type should be ``netCDF4.EnumType``, + but we avoid importing netCDF4 globally for performance. + """ + if enum_name not in self.ds.enumtypes: + return self.ds.createEnumType( + dtype, + enum_name, + enum_dict, + ) + datatype = self.ds.enumtypes[enum_name] + if datatype.enum_dict != enum_dict: + error_msg = ( + f"Cannot save variable `{var_name}` because an enum" + f" `{enum_name}` already exists in the Dataset but has" + " a different definition. To fix this error, make sure" + " each variable has a uniquely named enum in its" + " `encoding['dtype'].metadata` or, if they should share" + " the same enum type, make sure the enums are identical."
+ ) + raise ValueError(error_msg) + return datatype + def sync(self): self.ds.sync() diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 54383394b06..c3d57ad1903 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -566,7 +566,7 @@ def decode(self): class ObjectVLenStringCoder(VariableCoder): def encode(self): - return NotImplementedError + raise NotImplementedError def decode(self, variable: Variable, name: T_Name = None) -> Variable: if variable.dtype == object and variable.encoding.get("dtype", False) == str: @@ -574,3 +574,22 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: return variable else: return variable + + +class NativeEnumCoder(VariableCoder): + """Encode Enum into variable dtype metadata.""" + + def encode(self, variable: Variable, name: T_Name = None) -> Variable: + if ( + "dtype" in variable.encoding + and np.dtype(variable.encoding["dtype"]).metadata + and "enum" in variable.encoding["dtype"].metadata + ): + dims, data, attrs, encoding = unpack_for_encoding(variable) + data = data.astype(dtype=variable.encoding.pop("dtype")) + return Variable(dims, data, attrs, encoding, fastpath=True) + else: + return variable + + def decode(self, variable: Variable, name: T_Name = None) -> Variable: + raise NotImplementedError() diff --git a/xarray/conventions.py b/xarray/conventions.py index 94e285d8e1d..1d8e81e1bf2 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -48,10 +48,6 @@ T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore] -def _var_as_tuple(var: Variable) -> T_VarTuple: - return var.dims, var.data, var.attrs.copy(), var.encoding.copy() - - def _infer_dtype(array, name=None): """Given an object array with no missing values, infer its dtype from all elements.""" if array.dtype.kind != "O": @@ -111,7 +107,7 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike): def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: # TODO: move this 
from conventions to backends? (it's not CF related) if var.dtype.kind == "O": - dims, data, attrs, encoding = _var_as_tuple(var) + dims, data, attrs, encoding = variables.unpack_for_encoding(var) # leave vlen dtypes unchanged if strings.check_vlen_dtype(data.dtype) is not None: @@ -162,7 +158,7 @@ def encode_cf_variable( var: Variable, needs_copy: bool = True, name: T_Name = None ) -> Variable: """ - Converts an Variable into an Variable which follows some + Converts a Variable into a Variable which follows some of the CF conventions: - Nans are masked using _FillValue (or the deprecated missing_value) @@ -188,6 +184,7 @@ def encode_cf_variable( variables.CFScaleOffsetCoder(), variables.CFMaskCoder(), variables.UnsignedIntegerCoder(), + variables.NativeEnumCoder(), variables.NonStringCoder(), variables.DefaultFillvalueCoder(), variables.BooleanCoder(), @@ -447,7 +444,7 @@ def stackable(dim: Hashable) -> bool: decode_timedelta=decode_timedelta, ) except Exception as e: - raise type(e)(f"Failed to decode variable {k!r}: {e}") + raise type(e)(f"Failed to decode variable {k!r}: {e}") from e if decode_coords in [True, "coordinates", "all"]: var_attrs = new_vars[k].attrs if "coordinates" in var_attrs: @@ -633,7 +630,11 @@ def cf_decoder( decode_cf_variable """ variables, attributes, _ = decode_cf_variables( - variables, attributes, concat_characters, mask_and_scale, decode_times + variables, + attributes, + concat_characters, + mask_and_scale, + decode_times, ) return variables, attributes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d2ea0f8a1a4..e6e65c73a53 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4062,6 +4062,9 @@ def to_netcdf( name is the same as a coordinate name, then it is given the name ``"__xarray_dataarray_variable__"``. + [netCDF4 backend only] netCDF4 enums are decoded into the + dataarray dtype metadata. 
+ See Also -------- Dataset.to_netcdf diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 104b6d0867d..d01cfd7ff55 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1704,6 +1704,126 @@ def test_raise_on_forward_slashes_in_names(self) -> None: with self.roundtrip(ds): pass + @requires_netCDF4 + def test_encoding_enum__no_fill_value(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + v = nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=None, + ) + v[:] = 1 + with open_dataset(tmp_file) as original: + save_kwargs = {} + if self.engine == "h5netcdf": + save_kwargs["invalid_netcdf"] = True + with self.roundtrip(original, save_kwargs=save_kwargs) as actual: + assert_equal(original, actual) + assert ( + actual.clouds.encoding["dtype"].metadata["enum"] + == cloud_type_dict + ) + if self.engine != "h5netcdf": + # not implemented in h5netcdf yet + assert ( + actual.clouds.encoding["dtype"].metadata["enum_name"] + == "cloud_type" + ) + + @requires_netCDF4 + def test_encoding_enum__multiple_variable_with_enum(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=255, + ) + nc.createVariable( + "tifa", + cloud_type, + "time", + fill_value=255, + ) + with open_dataset(tmp_file) as original: + save_kwargs = {} + if self.engine == "h5netcdf": + save_kwargs["invalid_netcdf"] = True + with self.roundtrip(original, save_kwargs=save_kwargs) as actual: + assert_equal(original, actual) + assert ( + actual.clouds.encoding["dtype"] == 
actual.tifa.encoding["dtype"] + ) + assert ( + actual.clouds.encoding["dtype"].metadata + == actual.tifa.encoding["dtype"].metadata + ) + assert ( + actual.clouds.encoding["dtype"].metadata["enum"] + == cloud_type_dict + ) + if self.engine != "h5netcdf": + # not implemented in h5netcdf yet + assert ( + actual.clouds.encoding["dtype"].metadata["enum_name"] + == "cloud_type" + ) + + @requires_netCDF4 + def test_encoding_enum__error_multiple_variable_with_changing_enum(self): + """ + Given 2 variables, if they share the same enum type, + the 2 enum definition should be identical. + """ + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=255, + ) + nc.createVariable( + "tifa", + cloud_type, + "time", + fill_value=255, + ) + with open_dataset(tmp_file) as original: + assert ( + original.clouds.encoding["dtype"].metadata + == original.tifa.encoding["dtype"].metadata + ) + modified_enum = original.clouds.encoding["dtype"].metadata["enum"] + modified_enum.update({"neblig": 2}) + original.clouds.encoding["dtype"] = np.dtype( + "u1", + metadata={"enum": modified_enum, "enum_name": "cloud_type"}, + ) + if self.engine != "h5netcdf": + # not implemented yet in h5netcdf + with pytest.raises( + ValueError, + match=( + "Cannot save variable .*" + " because an enum `cloud_type` already exists in the Dataset .*" + ), + ): + with self.roundtrip(original): + pass + @requires_netCDF4 class TestNetCDF4Data(NetCDF4Base):