Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
abarciauskas-bgse committed Nov 15, 2024
1 parent f36adf2 commit e922ccd
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 18 deletions.
4 changes: 3 additions & 1 deletion virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from .manifest import ChunkManifest

if TYPE_CHECKING:
from zarr import Array # type: ignore

from .array import ManifestArray


Expand All @@ -25,7 +27,7 @@ def decorator(func):
return decorator


def check_combineable_zarr_arrays(arrays: Iterable["ManifestArray"]) -> None:
def check_combineable_zarr_arrays(arrays: Iterable["ManifestArray" | "Array"]) -> None:
"""
The downside of the ManifestArray approach compared to the VirtualZarrArray concatenation proposal is that
the result must also be a single valid zarr array, implying that the inputs must have the same dtype, codec etc.
Expand Down
27 changes: 15 additions & 12 deletions virtualizarr/tests/test_writers/test_icechunk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, cast

import pytest

Expand All @@ -11,6 +11,7 @@
from xarray import Dataset, concat, open_dataset, open_zarr
from xarray.core.variable import Variable
from zarr import Array, Group, group # type: ignore[import-untyped]
from zarr.core.metadata import ArrayV3Metadata

from virtualizarr.manifests import ChunkManifest, ManifestArray
from virtualizarr.writers.icechunk import dataset_to_icechunk, generate_chunk_key
Expand Down Expand Up @@ -69,7 +70,8 @@ def test_write_new_virtual_variable(
# assert dict(arr.attrs) == {"units": "km"}

# check dimensions
assert arr.metadata.dimension_names == ("x", "y")
if isinstance(arr.metadata, ArrayV3Metadata):
assert arr.metadata.dimension_names == ("x", "y")


def test_set_single_virtual_ref_without_encoding(
Expand Down Expand Up @@ -361,17 +363,17 @@ def generate_chunk_manifest(

def gen_virtual_dataset(
file_uri: str,
shape: tuple[int, int] = (3, 4),
chunk_shape: tuple[int, int] = (3, 4),
shape: tuple[int, ...] = (3, 4),
chunk_shape: tuple[int, ...] = (3, 4),
dtype: np.dtype = np.dtype("int32"),
compressor: dict = None,
filters: str = None,
fill_value: str = None,
encoding: dict = None,
compressor: Optional[dict] = None,
filters: Optional[list[dict[Any, Any]]] = None,
fill_value: Optional[str] = None,
encoding: Optional[dict] = None,
variable_name: str = "foo",
base_offset: int = 6144,
length: int = 48,
dims: list[str] = None,
dims: Optional[list[str]] = None,
):
manifest = generate_chunk_manifest(
file_uri,
Expand All @@ -391,7 +393,8 @@ def gen_virtual_dataset(
)
ma = ManifestArray(chunkmanifest=manifest, zarray=zarray)
ds = open_dataset(file_uri)
dims = dims or ds.sizes.keys()
ds_dims: list[str] = cast(list[str], list(ds.dims))
dims = dims or ds_dims
var = Variable(
data=ma,
dims=dims,
Expand Down Expand Up @@ -441,7 +444,7 @@ def test_append_virtual_ref_without_encoding(

## When appending to a virtual ref with encoding, it succeeds
def test_append_virtual_ref_with_encoding(
self, icechunk_storage: "StorageConfig", netcdf4_files_factory: callable
self, icechunk_storage: "StorageConfig", netcdf4_files_factory: Callable
):
import xarray.testing as xrt
from icechunk import IcechunkStore
Expand Down Expand Up @@ -496,7 +499,7 @@ def test_append_virtual_ref_with_encoding(

## When appending to a virtual ref with compression, it succeeds
def test_append_with_compression_succeeds(
self, icechunk_storage: "StorageConfig", netcdf4_files_factory: callable
self, icechunk_storage: "StorageConfig", netcdf4_files_factory: Callable
):
import xarray.testing as xrt
from icechunk import IcechunkStore
Expand Down
18 changes: 13 additions & 5 deletions virtualizarr/writers/icechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,19 @@ def resize_array(
append_axis: int,
) -> "Array":
existing_array = group[name]
if not isinstance(existing_array, Array):
raise ValueError("Expected existing array to be a zarr.core.Array")
new_shape = list(existing_array.shape)
new_shape[append_axis] += var.shape[append_axis]
return existing_array.resize(tuple(new_shape))


def get_axis(
dims: list[str],
dim_name: str,
dim_name: Optional[str],
) -> int:
if dim_name is None:
raise ValueError("dim_name must be provided")
return dims.index(dim_name)


Expand All @@ -150,7 +154,7 @@ def _check_compatible_arrays(
manifest_api.check_same_shapes_except_on_concat_axis(arr_shapes, append_axis)


def check_compatible_encodings(encoding1, encoding2):
def _check_compatible_encodings(encoding1, encoding2):
for key, value in encoding1.items():
if key in encoding2:
if encoding2[key] != value:
Expand All @@ -171,17 +175,19 @@ def write_virtual_variable_to_icechunk(
zarray = ma.zarray
mode = store.mode.str

dims = var.dims
dims: list[str] = cast(list[str], list(var.dims))
append_axis, existing_num_chunks, arr = None, None, None
if append_dim and append_dim not in dims:
raise ValueError(
f"append_dim {append_dim} not found in variable dimensions {dims}"
)
if mode == "a":
existing_array = group[name]
if not isinstance(existing_array, Array):
raise ValueError("Expected existing array to be a zarr.core.Array")
append_axis = get_axis(dims, append_dim)
# check if arrays can be concatenated
check_compatible_encodings(var.encoding, existing_array.attrs)
_check_compatible_encodings(var.encoding, existing_array.attrs)
_check_compatible_arrays(ma, existing_array, append_axis)

# determine number of existing chunks along the append axis
Expand Down Expand Up @@ -239,7 +245,9 @@ def generate_chunk_key(
f"append_axis {append_axis} is greater than the number of indices {len(index)}"
)
return "/".join(
str(ind + existing_num_chunks) if axis is append_axis else str(ind)
str(ind + existing_num_chunks)
if axis is append_axis and existing_num_chunks is not None
else str(ind)
for axis, ind in enumerate(index)
)

Expand Down

0 comments on commit e922ccd

Please sign in to comment.