diff --git a/openqdc/datasets/structure.py b/openqdc/datasets/structure.py index 8adfce4..f6dc077 100644 --- a/openqdc/datasets/structure.py +++ b/openqdc/datasets/structure.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from os import PathLike from os.path import join as p_join -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import zarr @@ -108,7 +108,7 @@ def load_data( preprocess_path: Union[str, PathLike], data_keys: List[str], data_types: Dict[str, np.dtype], - data_shapes: Dict[str, tuple[int, int]], + data_shapes: Dict[str, Tuple[int, int]], extra_data_keys: List[str], overwrite: bool, ):