Skip to content

Commit

Permalink
methods to update attrs and encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
marceloandrioni committed Sep 5, 2024
1 parent b0745fb commit 76de0c9
Showing 1 changed file with 187 additions and 52 deletions.
239 changes: 187 additions & 52 deletions src/load_by_step/_load_by_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import itertools
import time
from collections.abc import Mapping, Iterable
from typing import Annotated, Any
from typing import Annotated, Any, Literal
from pathlib import Path

import numpy as np
Expand Down Expand Up @@ -133,35 +133,110 @@ def _check_available_memory(self) -> None:

raise MemoryError(err_msg)

def update_attrs(self, attrs: Mapping[Any, Any] | None = None) -> None:
attrs = {} if attrs is None else attrs
self.dx.attrs = {**self.dx.attrs, **attrs}
@validate_func_args_and_return
def change_attrs(
self,
*,
mode: Literal["w", "a"] = "a",
**attrs_kwargs: Any,
) -> xr.Dataset | xr.DataArray:
"""Change attributes of Dataset/DataArray.
def update_encoding(self, encoding: Mapping[Any, Any] | None = None) -> None:
encoding = {} if encoding is None else encoding
self.dx.encoding = {**self.dx.encoding, **encoding}
Parameters
----------
mode : Literal["w", "a"], optional
'a' to append attribute to existing attributes or 'w' to replace the
full existing attributes.
**attrs_kwargs : Any
def to_localfile(
Returns
-------
Dataset/DataArray
Examples
--------
>>> ds = xr.tutorial.open_dataset("air_temperature_gradient")
>>> da = ds["Tair"]
>>> da2 = da.lbs.change_attrs(long_name="foo", units="bar")
"""

# use a copy so the state of the object is preserved
dx = self.dx.copy()

match mode:
case "w":
dx.attrs = attrs_kwargs
case "a":
dx.attrs = {**dx.attrs, **attrs_kwargs}
case _:
err_msg = ("Invalid mode. Must be 'w' (write) or 'a' (append).")
raise ValueError(err_msg)

return dx

@validate_func_args_and_return
def change_encoding(
self,
outfile: Annotated[Path, Field(strict=False)],
**kwargs: Any,
) -> None:
*,
mode: Literal["w", "a"] = "a",
**encoding_kwargs: Any,
) -> xr.Dataset | xr.DataArray:
"""Change encoding of Dataset/DataArray.
match outfile.suffix:
Parameters
----------
mode : Literal["w", "a"], optional
'a' to append encoding to existing encoding or 'w' to replace the
full existing encoding.
**encoding_kwargs : Any
Returns
-------
Dataset/DataArray
Examples
--------
>>> ds = xr.tutorial.open_dataset("air_temperature_gradient")
>>> da = ds["Tair"]
>>> da2 = da.lbs.change_encoding(dtype="f4", contiguous=False, zlib=True, complevel=1)
case ".nc":
self.dx.to_netcdf(outfile, **kwargs)
"""

case ".zarr":
self.dx.to_zarr(outfile, **kwargs)
# use a copy so the state of the object is preserved
dx = self.dx.copy()

match mode:
case "w":
dx.encoding = encoding_kwargs
case "a":
dx.encoding = {**dx.encoding, **encoding_kwargs}
case _:
err_msg = (
"Invalid file extenion. File extension must be"
" .nc or .zarr"
)
err_msg = ("Invalid mode. Must be 'w' (write) or 'a' (append).")
raise ValueError(err_msg)

return dx

def to_localfile(
    self,
    outfile: Annotated[Path, Field(strict=False)],
    **kwargs: Any,
) -> None:
    """Save the Dataset/DataArray to a local file.

    The writer is selected from the file extension: ``.nc`` uses
    ``to_netcdf`` and ``.zarr`` uses ``to_zarr``.

    Parameters
    ----------
    outfile : Path, str
        File to save the data into. Must have .nc or .zarr extension.
    **kwargs : Any
        Passed unchanged to the selected writer method.

    Raises
    ------
    ValueError
        If the file extension is not one of the supported extensions.
    """

    writers = {
        ".nc": self.dx.to_netcdf,
        ".zarr": self.dx.to_zarr,
    }

    # Path() makes the suffix lookup work for plain strings too; the
    # original object is still what gets handed to the writer.
    suffix = Path(outfile).suffix

    # Only the extension lookup is guarded, so a KeyError raised inside
    # the writer itself is not misreported as an invalid extension.
    try:
        writer = writers[suffix]
    except KeyError:
        err_msg = (
            "Invalid file extension. File extension must be one of: "
            + ", ".join(writers)
        )
        raise ValueError(err_msg) from None

    writer(outfile, **kwargs)


@xr.register_dataarray_accessor("lbs")
class DALoadByStep(DsDaMixin):
Expand Down Expand Up @@ -461,7 +536,7 @@ def load_by_step(

self._check_dims(dims_and_steps.keys())

# use a copy so the state of self.ds is preserved
# use a copy so the state of the object is preserved
ds = self.ds.copy()

# apply load for each data variable
Expand All @@ -485,14 +560,98 @@ def load_by_step(

return ds

@validate_func_args_and_return
def change_attrs(
    self,
    *,
    mode: Literal["w", "a"] = "a",
    variables_attrs: Mapping[Any, Mapping[Any, Any]] | None = None,
    **attrs_kwargs: Any,
) -> xr.Dataset:
    """Change attributes of Dataset and internal DataArrays.

    Parameters
    ----------
    mode : Literal["w", "a"], optional
        'a' to append attribute to existing attributes or 'w' to replace the
        full existing attributes. The same mode is applied to the Dataset
        and to every variable listed in `variables_attrs`.
    variables_attrs : dict of dicts, optional
        Dict with variables names as keys and variables attrs as values.
    **attrs_kwargs : Any
        Attributes to set on the Dataset itself.

    Returns
    -------
    Dataset
        The object returned by the parent ``change_attrs`` with the
        requested variables replaced by their updated versions.

    Examples
    --------
    >>> ds = xr.tutorial.open_dataset("air_temperature_gradient")
    >>> variables_attrs = {"Tair": {"long_name": "My variable"}}
    >>> ds2 = ds.lbs.change_attrs(title="My Dataset",
    ...                           variables_attrs=variables_attrs)
    """

    # The parent call supplies the object to work on, so no extra copy of
    # self.ds is made here (it would be discarded immediately).
    ds = super().change_attrs(mode=mode, **attrs_kwargs)  # type: ignore

    variables_attrs = {} if variables_attrs is None else variables_attrs
    for var, attrs in variables_attrs.items():
        ds[var] = ds[var].lbs.change_attrs(mode=mode, **attrs)

    return ds

@validate_func_args_and_return
def change_encoding(
    self,
    *,
    mode: Literal["w", "a"] = "a",
    variables_encoding: Mapping[Any, Mapping[Any, Any]] | None = None,
    **encoding_kwargs: Any,
) -> xr.Dataset:
    """Change encoding of Dataset and internal DataArrays.

    Parameters
    ----------
    mode : Literal["w", "a"], optional
        'a' to append encoding to existing encoding or 'w' to replace the
        full existing encoding. The same mode is applied to the Dataset
        and to every variable listed in `variables_encoding`.
    variables_encoding : dict of dicts, optional
        Dict with variables names as keys and variables encoding as values.
    **encoding_kwargs : Any
        Encoding entries to set on the Dataset itself.

    Returns
    -------
    Dataset
        The object returned by the parent ``change_encoding`` with the
        requested variables replaced by their updated versions.

    Examples
    --------
    >>> ds = xr.tutorial.open_dataset("air_temperature_gradient")
    >>> variables_encoding = {"Tair": {"dtype": "f4"}}
    >>> ds2 = ds.lbs.change_encoding(unlimited_dims="time",
    ...                              variables_encoding=variables_encoding)
    """

    # The parent call supplies the object to work on, so no extra copy of
    # self.ds is made here (it would be discarded immediately).
    ds = super().change_encoding(mode=mode, **encoding_kwargs)  # type: ignore

    variables_encoding = {} if variables_encoding is None else variables_encoding
    for var, encoding in variables_encoding.items():
        ds[var] = ds[var].lbs.change_encoding(mode=mode, **encoding)

    return ds

@validate_func_args_and_return
def load_and_save_by_step(
self,
*,
outfile: Annotated[NewPath, Field(strict=False)],
indexers: Mapping[Any, PositiveInt] | None = None,
attrs: Mapping[Any, Mapping[Any, Any]] | None = None,
encoding: Mapping[Any, Mapping[Any, Any]] | None = None,
to_outfile_kwargs: Mapping[Any, Any] | None = None,
seconds_between_requests: NonNegativeFloat = 0,
**indexers_kwargs: PositiveInt,
Expand All @@ -513,15 +672,9 @@ def load_and_save_by_step(
integers.
One of indexers or indexers_kwargs must be provided.
outfile : Path, str
File to save the data into.
attrs : dict, optional
A dict of dicts where the key is a variable name and the value is
a dict with key:value attrs pairs. To set the dataset attrs use
variable "global".
encoding : dict, optional
A dict of dicts where the key is a variable name and the value is
a dict with key:value encoding pairs. To set the dataset encoding
use variable "global".
File to save the data into. Must have .nc or .zarr extension.
to_outfile_kwargs : dict, optional
Dict with kwargs passed to .to_netcdf()/.to_zarr() methods.
seconds_between_requests : float, optional
Wait time in seconds between requests.
**indexers_kwargs : {dim: indexer, ...}, optional
Expand All @@ -538,17 +691,8 @@ def load_and_save_by_step(
purpose. A real application would be to read data from a THREDDS server.
>>> ds = xr.tutorial.open_dataset("air_temperature_gradient")
>>> attrs = {
... "global": {"title": "My example dataset"},
... "Tair": {"long_name": "My example variable"},
... }
>>> encoding = {
... "Tair": {"dtype": "f4", "zlib": "True", "deflate": 1},
... }
>>> ds.lbs.load_and_save_by_step(time=500,
lon=30,
attrs=attrs,
encoding=encoding,
outfile="/tmp/foo.nc",
seconds_between_requests=1)
Expand All @@ -561,21 +705,12 @@ def load_and_save_by_step(

self._check_dims(dims_and_steps.keys())

attrs = {} if attrs is None else attrs
encoding = {} if encoding is None else encoding
to_outfile_kwargs = {} if to_outfile_kwargs is None else to_outfile_kwargs

# use a copy so the state of self.ds is preserved
# use a copy so the state of the object is preserved
ds = self.ds.copy()

# update Dataset and DataArray attrs and encoding
ds.lbs.update_attrs(attrs.get("global", None))
ds.lbs.update_encoding(encoding.get("global", None))
for var in list(ds.data_vars):
ds[var].lbs.update_attrs(attrs.get(var, None))
ds[var].lbs.update_encoding(encoding.get(var, None))

# create empty Dataset
# save empty Dataset
ds[[]].lbs.to_localfile(outfile, mode="w", **to_outfile_kwargs)

# apply load for each data variable
Expand Down

0 comments on commit 76de0c9

Please sign in to comment.