Skip to content

Commit

Permalink
Strict docs generation, completed type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
FNTwin committed Jul 16, 2024
1 parent 7ec4164 commit 150842a
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/API/basedataset.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
::: openqdc.datasets.base
::: openqdc.datasets.base
2 changes: 1 addition & 1 deletion docs/API/formats.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
::: openqdc.datasets.structure
::: openqdc.datasets.structure
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use_directory_urls: false
docs_dir: "docs"

# Fail on warnings to detect issues with types and docstrings
strict: false
strict: true

nav:
- Overview: index.md
Expand Down
18 changes: 12 additions & 6 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""The BaseDataset defining shared functionality between all datasets."""

import os
from collections import Iterable

try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
from functools import partial
from itertools import compress
from os.path import join as p_join
Expand Down Expand Up @@ -102,7 +106,7 @@ def __init__(
recompute_statistics: bool = False,
transform: Optional[Callable] = None,
read_as_zarr: bool = False,
regressor_kwargs={
regressor_kwargs: dict = {
"solver_type": "linear",
"sub_sample": None,
"stride": 1,
Expand Down Expand Up @@ -394,7 +398,9 @@ def collate_list(self, list_entries: List[Dict]) -> dict:

return res

def save_preprocess(self, data_dict, upload=False, overwrite=True, as_zarr: bool = False):
def save_preprocess(
self, data_dict: dict[str, np.ndarray], upload: bool = False, overwrite: bool = True, as_zarr: bool = False
):
"""
Save the preprocessed data to the cache directory and optionally upload it to the remote storage.
Expand Down Expand Up @@ -448,7 +454,7 @@ def _convert_on_loading(self, x, key):
else:
return x

def is_preprocessed(self):
def is_preprocessed(self) -> bool:
"""
Check if the dataset is preprocessed and available online or locally.
Expand Down Expand Up @@ -511,7 +517,7 @@ def upload(self, overwrite: bool = False, as_zarr: bool = False):
local_path = p_join(self.preprocess_path, "props.pkl" if not as_zarr else "metadata.zip")
push_remote(local_path, overwrite=overwrite)

def save_xyz(self, idx: int, energy_method: int = 0, path: Optional[str] = None, ext=True):
def save_xyz(self, idx: int, energy_method: int = 0, path: Optional[str] = None, ext: bool = True):
"""
Save a single entry at index idx as an extxyz file.
Expand Down Expand Up @@ -548,7 +554,7 @@ def to_xyz(self, energy_method: int = 0, path: Optional[str] = None):
):
write_extxyz(f, atoms, append=True)

def get_ase_atoms(self, idx: int, energy_method: int = 0, ext=True) -> Atoms:
def get_ase_atoms(self, idx: int, energy_method: int = 0, ext: bool = True) -> Atoms:
"""
Get the ASE atoms object for the entry at index idx.
Expand Down
6 changes: 2 additions & 4 deletions openqdc/datasets/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def save_preprocess(
preprocess_path: path to the preprocessed data file
data_keys: list of keys to load from the data file
data_dict: dictionary of data to save
data_shapes: dictionary of shapes for each key
extra_data_keys: list of keys to load from the extra data file
overwrite: whether to overwrite the local cache
extra_data_types: dictionary of data types for each key
"""
raise NotImplementedError

Expand All @@ -87,7 +86,6 @@ def load_extra_files(
preprocess_path: path to the preprocessed data file
data_keys: list of keys to load from the data file
pkl_data_keys: list of keys to load from the extra files
extra_data_keys: list of keys to load from the extra data file
overwrite: whether to overwrite the local cache
"""
raise NotImplementedError
Expand All @@ -97,7 +95,7 @@ def join_and_ext(self, path: Union[str, PathLike], filename: str) -> Union[str,
Join a path and a filename and add the correct extension.
Parameters:
Path: the path to join
path: the path to join
filename: the filename to join
Returns:
Expand Down
6 changes: 3 additions & 3 deletions openqdc/utils/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def __init__(
stride: int = 1,
subsample: Optional[Union[float, int]] = None,
remove_nan: bool = True,
*args,
**kwargs,
*args: any,
**kwargs: any,
):
"""
Regressor class for preparing and solving regression problem for isolated atom energies.
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
self._post_init()

@classmethod
def from_openqdc_dataset(cls, dataset, *args, **kwargs) -> "Regressor":
def from_openqdc_dataset(cls, dataset: any, *args: any, **kwargs: any) -> "Regressor":
"""
Initialize the regressor object from an openqdc dataset. This is the default method.
*args and **kwargs are passed to the __init__ method and depend on the specific regressor.
Expand Down

0 comments on commit 150842a

Please sign in to comment.