diff --git a/pathwaysutils/persistence/helper.py b/pathwaysutils/persistence/helper.py index 9153844..56514ca 100644 --- a/pathwaysutils/persistence/helper.py +++ b/pathwaysutils/persistence/helper.py @@ -172,7 +172,7 @@ def get_read_request( def get_bulk_read_request( location_path: str, names: str, - dtypes: jnp.dtype, + dtypes: np.dtype, shapes: Sequence[Sequence[int]], shardings: Sequence[jax.sharding.Sharding], devices: Sequence[jax.Device],