diff --git a/docs/API/basedataset.md b/docs/API/basedataset.md index c5357a2..cdaeee7 100644 --- a/docs/API/basedataset.md +++ b/docs/API/basedataset.md @@ -1 +1 @@ -::: openqdc.datasets.base \ No newline at end of file +::: openqdc.datasets.base diff --git a/docs/API/formats.md b/docs/API/formats.md index 77a6f21..fab9816 100644 --- a/docs/API/formats.md +++ b/docs/API/formats.md @@ -1 +1 @@ -::: openqdc.datasets.structure \ No newline at end of file +::: openqdc.datasets.structure diff --git a/mkdocs.yml b/mkdocs.yml index 70eba3a..5df290d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,7 +10,7 @@ use_directory_urls: false docs_dir: "docs" # Fail on warnings to detect issues with types and docstring -strict: false +strict: true nav: - Overview: index.md diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index 6851e5e..ebfb68f 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -1,7 +1,11 @@ """The BaseDataset defining shared functionality between all datasets.""" import os -from collections import Iterable + +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable from functools import partial from itertools import compress from os.path import join as p_join @@ -102,7 +106,7 @@ def __init__( recompute_statistics: bool = False, transform: Optional[Callable] = None, read_as_zarr: bool = False, - regressor_kwargs={ + regressor_kwargs: dict = { "solver_type": "linear", "sub_sample": None, "stride": 1, @@ -394,7 +398,9 @@ def collate_list(self, list_entries: List[Dict]) -> dict: return res - def save_preprocess(self, data_dict, upload=False, overwrite=True, as_zarr: bool = False): + def save_preprocess( + self, data_dict: dict[str, np.ndarray], upload: bool = False, overwrite: bool = True, as_zarr: bool = False + ): """ Save the preprocessed data to the cache directory and optionally upload it to the remote storage. @@ -448,7 +454,7 @@ def _convert_on_loading(self, x, key): else: return x - def is_preprocessed(self): + def is_preprocessed(self) -> bool: """ Check if the dataset is preprocessed and available online or locally. @@ -511,7 +517,7 @@ def upload(self, overwrite: bool = False, as_zarr: bool = False): local_path = p_join(self.preprocess_path, "props.pkl" if not as_zarr else "metadata.zip") push_remote(local_path, overwrite=overwrite) - def save_xyz(self, idx: int, energy_method: int = 0, path: Optional[str] = None, ext=True): + def save_xyz(self, idx: int, energy_method: int = 0, path: Optional[str] = None, ext: bool = True): """ Save a single entry at index idx as an extxyz file. @@ -548,7 +554,7 @@ def to_xyz(self, energy_method: int = 0, path: Optional[str] = None): ): write_extxyz(f, atoms, append=True) - def get_ase_atoms(self, idx: int, energy_method: int = 0, ext=True) -> Atoms: + def get_ase_atoms(self, idx: int, energy_method: int = 0, ext: bool = True) -> Atoms: """ Get the ASE atoms object for the entry at index idx. diff --git a/openqdc/datasets/structure.py b/openqdc/datasets/structure.py index c2672bb..bccb1e2 100644 --- a/openqdc/datasets/structure.py +++ b/openqdc/datasets/structure.py @@ -63,9 +63,8 @@ def save_preprocess( preprocess_path: path to the preprocessed data file data_keys: list of keys to load from the data file data_dict: dictionary of data to save - data_shapes: dictionary of shapes for each key extra_data_keys: list of keys to load from the extra data file - overwrite: whether to overwrite the local cache + extra_data_types: dictionary of data types for each key """ raise NotImplementedError @@ -87,7 +86,6 @@ def load_extra_files( preprocess_path: path to the preprocessed data file data_keys: list of keys to load from the data file pkl_data_keys: list of keys to load from the extra files - extra_data_keys: list of keys to load from the extra data file overwrite: whether to overwrite the local cache """ raise NotImplementedError @@ -97,7 +95,7 @@ def join_and_ext(self, path: Union[str, PathLike], filename: str) -> Union[str, Join a path and a filename and add the correct extension. Parameters: - Path: the path to join + path: the path to join filename: the filename to join Returns: diff --git a/openqdc/utils/regressor.py b/openqdc/utils/regressor.py index a980ef5..8313621 100644 --- a/openqdc/utils/regressor.py +++ b/openqdc/utils/regressor.py @@ -79,8 +79,8 @@ def __init__( stride: int = 1, subsample: Optional[Union[float, int]] = None, remove_nan: bool = True, - *args, - **kwargs, + *args: any, + **kwargs: any, ): """ Regressor class for preparing and solving regression problem for isolated atom energies. @@ -117,7 +117,7 @@ def __init__( self._post_init() @classmethod - def from_openqdc_dataset(cls, dataset, *args, **kwargs) -> "Regressor": + def from_openqdc_dataset(cls, dataset: any, *args: any, **kwargs: any) -> "Regressor": """ Initialize the regressor object from an openqdc dataset. This is the default method. *args and and **kwargs are passed to the __init__ method and depends on the specific regressor.