diff --git a/openqdc/cli.py b/openqdc/cli.py index 186f55f..dd9caaa 100644 --- a/openqdc/cli.py +++ b/openqdc/cli.py @@ -19,22 +19,34 @@ def sanitize(dictionary): + """ + Sanitize dataset names to be used in the CLI. + """ return {k.lower().replace("_", "").replace("-", ""): v for k, v in dictionary.items()} SANITIZED_AVAILABLE_DATASETS = sanitize(AVAILABLE_DATASETS) -def exist_dataset(dataset): +def exist_dataset(dataset) -> bool: + """ + Check if dataset is available in the openQDC datasets. + """ if dataset not in sanitize(AVAILABLE_DATASETS): logger.error(f"{dataset} is not available. Please open an issue on Github for the team to look into it.") return False return True -def format_entry(empty_dataset): +def format_entry(empty_dataset, max_num_to_display: int = 6): + """ + Format the entry for the table. + max_num_to_display: int = 6, + Maximum number of energy methods to display. Used to keep the table format + readable in case of datasets with many energy methods. [ex. MultiXQM9] + """ energy_methods = [str(x) for x in empty_dataset.__energy_methods__] - max_num_to_display = 6 + if len(energy_methods) > 6: entry = ",".join(energy_methods[:max_num_to_display]) + "..." else: @@ -48,7 +60,7 @@ def download( overwrite: Annotated[ bool, typer.Option( - help="Whether to overwrite or force the re-download of the datasets.", + help="Whether to force the re-download of the datasets and overwrite the current cached dataset.", ), ] = False, cache_dir: Annotated[ @@ -60,13 +72,14 @@ def download( as_zarr: Annotated[ bool, typer.Option( - help="Whether to overwrite or force the re-download of the datasets.", + help="Whether to use a zarr format for the datasets instead of memmap.", ), ] = False, gs: Annotated[ bool, typer.Option( - help="Whether to use gs to re-download of the datasets.", + help="Whether source to use for downloading. If True, Google Storage will be used." 
+ + "Otherwise, AWS S3 will be used", ), ] = False, ): @@ -78,6 +91,7 @@ def download( """ if gs: os.environ["OPENQDC_DOWNLOAD_API"] = "gs" + for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)): if exist_dataset(dataset): ds = SANITIZED_AVAILABLE_DATASETS[dataset].no_init() @@ -93,7 +107,7 @@ def download( @app.command() def datasets(): """ - Print a table of the available openQDC datasets and some informations. + Print a formatted table of the available openQDC datasets and some informations. """ table = PrettyTable(["Name", "Type of Energy", "Forces", "Level of theory"]) for dataset in AVAILABLE_DATASETS: @@ -118,7 +132,7 @@ def fetch( overwrite: Annotated[ bool, typer.Option( - help="Whether to overwrite or force the re-download of the files.", + help="Whether to overwrite or force the re-download of the raw files.", ), ] = False, cache_dir: Annotated[ @@ -129,17 +143,14 @@ def fetch( ] = None, ): """ - Download the raw datasets files from the main openQDC hub. - overwrite: bool = False, - If True, the files will be re-downloaded and overwritten. - cache_dir: Optional[str] = None, - Path to the cache. If not provided, the default cache directory will be used. - Special case: if the dataset is "all", "potential", "interaction". - all: all available datasets will be downloaded. 
- potential: all the potential datasets will be downloaded - interaction: all the interaction datasets will be downloaded - Example: - openqdc fetch Spice + Download the raw datasets files from the main openQDC hub.\n + Special case: if the dataset is "all", "potential", "interaction".\n + all: all available datasets will be downloaded.\n + potential: all the potential datasets will be downloaded\n + interaction: all the interaction datasets will be downloaded\n\n + + Example:\n + openqdc fetch Spice """ if datasets[0].lower() == "all": dataset_names = list(sanitize(AVAILABLE_DATASETS).keys()) @@ -163,18 +174,27 @@ def preprocess( overwrite: Annotated[ bool, typer.Option( - help="Whether to overwrite or force the re-download of the datasets.", + help="Whether to overwrite the current cached datasets.", ), ] = True, upload: Annotated[ bool, typer.Option( - help="Whether to try the upload to the remote storage.", + help="Whether to attempt the upload to the remote storage. Must have write permissions.", + ), + ] = False, + as_zarr: Annotated[ + bool, + typer.Option( + help="Whether to preprocess as a zarr format or a memmap format.", ), ] = False, ): """ Preprocess a raw dataset (previously fetched) into a openqdc dataset and optionally push it to remote. + + Example: + openqdc preprocess Spice QMugs """ for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)): if exist_dataset(dataset): @@ -192,7 +212,7 @@ def upload( overwrite: Annotated[ bool, typer.Option( - help="Whether to overwrite or force the re-download of the datasets.", + help="Whether to overwrite the remote files if they are present.", ), ] = True, as_zarr: Annotated[ @@ -204,6 +224,9 @@ def upload( ): """ Upload a preprocessed dataset to the remote storage. 
+ + Example: + openqdc upload Spice --overwrite """ for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)): if exist_dataset(dataset): @@ -216,23 +239,23 @@ def upload( @app.command() -def convert_to_zarr( +def convert( datasets: List[str], overwrite: Annotated[ bool, typer.Option( - help="Whether to overwrite or force the re-download of the datasets.", + help="Whether to overwrite the current zarr cached datasets.", ), ] = False, download: Annotated[ bool, typer.Option( - help="Whether to force the re-download of the datasets.", + help="Whether to force the re-download of the memmap datasets.", ), ] = False, ): """ - Convert a preprocessed dataset to the zarr file format. + Convert a preprocessed dataset from a memmap dataset to a zarr dataset. """ import os from os.path import join as p_join @@ -243,6 +266,9 @@ def convert_to_zarr( from openqdc.utils.io import load_pkl def silent_remove(filename): + """ + Zarr zip files are currently not overwritable. This function is used to remove the file if it exists. 
+ """ try: os.remove(filename) except OSError: @@ -305,7 +331,7 @@ def silent_remove(filename): @app.command() -def show_cache(): +def cache(): """ Get the current local cache path of openQDC """ diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index 1c582e0..3fc51fe 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -361,7 +361,7 @@ def set_array_format(self, format: str): def read_raw_entries(self): raise NotImplementedError - def collate_list(self, list_entries): + def collate_list(self, list_entries: List[Dict]): # concatenate entries res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) for key in list_entries[0]} diff --git a/openqdc/datasets/dataset_structure.py b/openqdc/datasets/dataset_structure.py index e072afd..2e2a318 100644 --- a/openqdc/datasets/dataset_structure.py +++ b/openqdc/datasets/dataset_structure.py @@ -1,7 +1,7 @@ import pickle as pkl from abc import ABC, abstractmethod from os.path import join as p_join -from typing import List, Optional +from typing import Callable, List, Optional import numpy as np import zarr @@ -23,7 +23,10 @@ def ext(self): @property @abstractmethod - def load_fn(self): + def load_fn(self) -> Callable: + """ + Function to use for loading the data. + """ raise NotImplementedError def add_extension(self, filename): diff --git a/openqdc/datasets/io.py b/openqdc/datasets/io.py index 1e621f7..ed214fb 100644 --- a/openqdc/datasets/io.py +++ b/openqdc/datasets/io.py @@ -17,6 +17,8 @@ def try_retrieve(obj, callable, default=None): class FromFileDataset(BaseDataset, ABC): + """Abstract class for datasets that read from a common format file like xyz, netcdf, gro, hdf5, etc.""" + def __init__( self, path: List[str], @@ -35,12 +37,30 @@ def __init__( }, ): """ - Create a dataset from a xyz file. + Create a dataset from a list of files. Parameters ---------- path : List[str] The path to the file or a list of paths. 
+ dataset_name : Optional[str], optional + The name of the dataset, by default None. + energy_type : Optional[str], optional + The type of isolated atom energy by default "regression". + Supported types: ["formation", "regression", "null", None] + energy_unit + Energy unit of the dataset. Default is "hartree". + distance_unit + Distance unit of the dataset. Default is "ang". + level_of_theory: Optional[QmMethod, str] + The level of theory of the dataset. + Used if energy_type is "formation" to fetch the correct isolated atom energies. + transform, optional + transformation to apply to the __getitem__ calls + regressor_kwargs + Dictionary of keyword arguments to pass to the regressor. + Default: {"solver_type": "linear", "sub_sample": None, "stride": 1} + solver_type can be one of ["linear", "ridge"] """ self.path = [path] if isinstance(path, str) else path self.__name__ = self.__class__.__name__ if dataset_name is None else dataset_name @@ -62,29 +82,19 @@ def __init__( self.set_array_format(array_format) self._post_init(True, energy_unit, distance_unit) - def __str__(self): - return self.__name__.lower() - - def __repr__(self): - return str(self) - @abstractmethod def read_as_atoms(self, path: str) -> List[Atoms]: """ - Method that reads a path and return a list of Atoms objects. + Method that reads a file and return a list of Atoms objects. + path : str + The path to the file. """ raise NotImplementedError - def collate_list(self, list_entries): - res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) for key in list_entries[0]} - csum = np.cumsum(res.get("n_atoms")) - x = np.zeros((csum.shape[0], 2), dtype=np.int32) - x[1:, 0], x[:, 1] = csum[:-1], csum - res["position_idx_range"] = x - - return res - - def read_raw_entries(self): + def read_raw_entries(self) -> List[dict]: + """ + Process the files and return a list of data objects. 
+ """ entries_list = [] for path in self.path: for entry in self.read_as_atoms(path): @@ -96,6 +106,11 @@ def _read_and_preprocess(self): self.data = self.collate_list(entries_list) def _convert_to_record(self, obj: Atoms): + """ + Convert an Atoms object to a record for the openQDC dataset processing. + obj : Atoms + The ase.Atoms object to convert + """ name = obj.info.get("name", None) subset = obj.info.get("subset", str(self)) positions = obj.positions @@ -116,8 +131,18 @@ def _convert_to_record(self, obj: Atoms): n_atoms=np.array([len(positions)], dtype=np.int32), ) + def __str__(self): + return self.__name__.lower() + + def __repr__(self): + return str(self) + class XYZDataset(FromFileDataset): + """ + Baseclass to read datasets from xyz and extxyz files. + """ + def read_as_atoms(self, path): from ase.io import iread diff --git a/openqdc/datasets/potential/alchemy.py b/openqdc/datasets/potential/alchemy.py index 969d13d..53d7a3d 100644 --- a/openqdc/datasets/potential/alchemy.py +++ b/openqdc/datasets/potential/alchemy.py @@ -44,6 +44,11 @@ def read_mol(file, energy): class Alchemy(BaseDataset): + """ + https://alchemy.tencent.com/ + https://arxiv.org/abs/1906.09427 + """ + __name__ = "alchemy" __energy_methods__ = [ diff --git a/openqdc/datasets/potential/proteinfragments.py b/openqdc/datasets/potential/proteinfragments.py index 4f5ea07..d26bab8 100644 --- a/openqdc/datasets/potential/proteinfragments.py +++ b/openqdc/datasets/potential/proteinfragments.py @@ -90,7 +90,7 @@ def _unpack_data_tuple(self, data): # graphs is smiles class ProteinFragments(BaseDataset): - """ """ + """https://www.science.org/doi/10.1126/sciadv.adn4397""" __name__ = "proteinfragments" @@ -134,6 +134,10 @@ def read_raw_entries(self): class MDDataset(ProteinFragments): + """ + Part of the proteinfragments dataset that is generated from the molecular dynamics with their model. 
+ """ + __name__ = "mddataset" __links__ = { diff --git a/openqdc/datasets/potential/qmx.py b/openqdc/datasets/potential/qmx.py index 4b682a9..3719118 100644 --- a/openqdc/datasets/potential/qmx.py +++ b/openqdc/datasets/potential/qmx.py @@ -1,4 +1,5 @@ import os +from abc import ABC from os.path import join as p_join import datamol as dm @@ -32,21 +33,9 @@ def extract_ani2_entries(properties): return res -class QMX(BaseDataset): +class QMX(ABC, BaseDataset): """ - The ANI-1 dataset is a collection of 22 x 10^6 structural conformations from 57,000 distinct small - organic molecules with energy labels calculated using DFT. The molecules - contain 4 distinct atoms, C, N, O and H. - - Usage - ```python - from openqdc.datasets import ANI1 - dataset = ANI1() - ``` - - References: - - ANI-1: https://www.nature.com/articles/sdata2017193 - - Github: https://github.com/aiqm/ANI1x_datasets + QMX dataset base abstract class """ __name__ = "qm9" @@ -335,3 +324,15 @@ class QM9(QMX): "WB97X-D:def2-svp", "WB97X-D:def2-tzvp", ] + + __energy_methods__ = [ + PotentialMethod.NONE, # "wb97x/6-31g(d)" + PotentialMethod.NONE, + PotentialMethod.NONE, + PotentialMethod.NONE, + PotentialMethod.NONE, + PotentialMethod.NONE, + PotentialMethod.NONE, + PotentialMethod.NONE, + PotentialMethod.NONE, + ] diff --git a/openqdc/datasets/potential/vqm24.py b/openqdc/datasets/potential/vqm24.py index 4d54d69..d72b6b3 100644 --- a/openqdc/datasets/potential/vqm24.py +++ b/openqdc/datasets/potential/vqm24.py @@ -40,7 +40,7 @@ def read_npz_entry(raw_path): # graphs is smiles class VQM24(BaseDataset): - """ """ + """https://arxiv.org/abs/2405.05961""" __name__ = "vqm24" diff --git a/openqdc/datasets/statistics.py b/openqdc/datasets/statistics.py index 7197c2c..dfd99ea 100644 --- a/openqdc/datasets/statistics.py +++ b/openqdc/datasets/statistics.py @@ -2,7 +2,7 @@ from copy import deepcopy from dataclasses import asdict, dataclass from os.path import join as p_join -from typing import Optional +from typing 
import Callable, Optional import numpy as np from loguru import logger @@ -17,9 +17,15 @@ class StatisticsResults: """ def to_dict(self): + """ + Convert the class to a dictionary + """ return asdict(self) - def transform(self, func): + def transform(self, func: Callable): + """ + Apply a function to all the attributes of the class + """ for k, v in self.to_dict().items(): if v is not None: setattr(self, k, func(v)) @@ -55,6 +61,14 @@ class StatisticManager: """ def __init__(self, dataset, recompute: bool = False, *statistic_calculators: "AbstractStatsCalculator"): + """ + dataset : openqdc.datasets.base.BaseDataset + The dataset object to compute the statistics + recompute : bool, default = False + Flag to recompute the statistics + *statistic_calculators : AbstractStatsCalculator + statistic calculators to run + """ self._state = {} self._results = {} self._statistic_calculators = [ @@ -120,7 +134,7 @@ class AbstractStatsCalculator(ABC): """ Abstract class that defines the interface for all the calculators object and the methods to - compute the statistics + compute the statistics. """ # State Dependencies of the calculator to skip part of the calculation @@ -140,6 +154,28 @@ def __init__( atom_charges: Optional[np.ndarray] = None, forces: Optional[np.ndarray] = None, ): + """ + name : str + Name of the dataset for saving and loading. + energy_type : str, default = None + Type of the energy for the computation of the statistics. Used for loading and saving. 
+ force_recompute : bool, default = False + Flag to force the recomputation of the statistics + energies : np.ndarray, default = None + Energies of the dataset + n_atoms : np.ndarray, default = None + Number of atoms in the dataset + atom_species : np.ndarray, default = None + Atomic species of the dataset + position_idx_range : np.ndarray, default = None + Position index range of the dataset + e0_matrix : np.ndarray, default = None + Isolated atom energies matrix of the dataset + atom_charges : np.ndarray, default = None + Atomic charges of the dataset + forces : np.ndarray, default = None + Forces of the dataset + """ self.name = name self.energy_type = energy_type self.force_recompute = force_recompute @@ -173,7 +209,7 @@ def root(self): @classmethod def from_openqdc_dataset(cls, dataset, recompute: bool = False): """ - Create a calculator object from a dataset object + Create a calculator object from a dataset object. """ obj = cls( name=dataset.__name__, @@ -203,7 +239,6 @@ def save_statistics(self) -> None: """ Save statistics file to the dataset folder as a pkl file """ - print(self.preprocess_path) save_pkl(self.result, self.preprocess_path) def attempt_load(self) -> bool: diff --git a/openqdc/utils/download_api.py b/openqdc/utils/download_api.py index 2cbf442..c48b4b6 100644 --- a/openqdc/utils/download_api.py +++ b/openqdc/utils/download_api.py @@ -29,7 +29,7 @@ @dataclass class FileSystem: """ - A class to handle file system operations + A basic class to handle file system operations """ public_endpoint: Optional[AbstractFileSystem] = None @@ -38,22 +38,31 @@ class FileSystem: endpoint_url = "https://874f02b9d981bd6c279e979c0d91c4b4.r2.cloudflarestorage.com" def __init__(self): - load_dotenv() + load_dotenv() # load environment variables from .env self.KEY = os.getenv("CLOUDFARE_KEY", None) self.SECRET = os.getenv("CLOUDFARE_SECRET", None) @property def public(self): + """ + Return the public remote filesystem with read permission + """ self.connect() 
return self.public_endpoint @property def private(self): + """ + Return the private remote filesystem with write permission + """ self.connect() return self.private_endpoint @property def local(self): + """ + Return the local filesystem + """ return self.local_endpoint @property @@ -65,7 +74,7 @@ def is_connected(self): def connect(self): """ - Attempt connection to the public and private endpoints + Attempt connection to the public and private remote endpoints """ if not self.is_connected: with warnings.catch_warnings(): diff --git a/openqdc/utils/package_utils.py b/openqdc/utils/package_utils.py index 990f6cb..e1381da 100644 --- a/openqdc/utils/package_utils.py +++ b/openqdc/utils/package_utils.py @@ -1,3 +1,4 @@ +# from openFF package import importlib from functools import wraps from typing import Any, Callable, TypeVar diff --git a/openqdc/utils/regressor.py b/openqdc/utils/regressor.py index 1d3e50a..a38d6c2 100644 --- a/openqdc/utils/regressor.py +++ b/openqdc/utils/regressor.py @@ -88,6 +88,10 @@ def __init__( @classmethod def from_openqdc_dataset(cls, dataset, *args, **kwargs): + """ + Initialize the regressor from an openqdc dataset. + *args and **kwargs are passed to the __init__ method and depend on the specific regressor. + """ energies = dataset.data["energies"] position_idx_range = dataset.data["position_idx_range"] atomic_numbers = dataset.data["atomic_inputs"][:, 0].astype("int32") @@ -137,6 +141,9 @@ def _prepare_inputs(self) -> Tuple[np.ndarray, np.ndarray]: self.y = B def solve(self): + """ + Solve the regression problem and return the predicted isolated energies and the estimated uncertainty. + """ logger.info(f"Solving regression with {self.solver}.") E0_list, cov_list = [], [] for energy_idx in range(self.y.shape[1]): @@ -157,6 +164,11 @@ def __call__(self): def atom_standardization(X, y): + """ + Standardize the energies and the atom counts. + This will make the calculated uncertainty more + meaningful. 
+ """ X_norm = X.sum() X = X / X_norm y = y / X_norm @@ -165,6 +177,11 @@ def atom_standardization(X, y): class LinearSolver(Solver): + """ + Linear regression solver. + No Uncertainty associated as it is quite small. + """ + _regr_str = "LinearRegression" @staticmethod @@ -175,6 +192,10 @@ def solve(X, y): class RidgeSolver(Solver): + """ + Ridge regression solver. + """ + _regr_str = "RidgeRegression" @staticmethod diff --git a/openqdc/utils/units.py b/openqdc/utils/units.py index d8613a5..d7fa834 100644 --- a/openqdc/utils/units.py +++ b/openqdc/utils/units.py @@ -40,7 +40,10 @@ class EnergyTypeConversion(ConversionEnum, StrEnum): MEV = "mev" RYD = "ryd" - def to(self, energy: "EnergyTypeConversion"): + def to(self, energy: "EnergyTypeConversion") -> Callable[[float], float]: + """ + Return a callable to convert the energy to the desired units. + """ return get_conversion(str(self), str(energy)) @@ -54,7 +57,10 @@ class DistanceTypeConversion(ConversionEnum, StrEnum): NM = "nm" BOHR = "bohr" - def to(self, distance: "DistanceTypeConversion", fraction: bool = False): + def to(self, distance: "DistanceTypeConversion", fraction: bool = False) -> Callable[[float], float]: + """ + Return a callable to convert the distance to the desired units. + """ return get_conversion(str(self), str(distance)) if not fraction else get_conversion(str(distance), str(self)) @@ -91,7 +97,10 @@ def __init__(self, energy: EnergyTypeConversion, distance: DistanceTypeConversio def __str__(self): return f"{self.energy}/{self.distance}" - def to(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion): + def to(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion) -> Callable[[float], float]: + """ + Return a callable to convert the force to the desired units. 
+ """ return lambda x: self.distance.to(distance, fraction=True)(self.energy.to(energy)(x)) @@ -133,7 +142,14 @@ def __call__(self, x): return self.fn(x) -def get_conversion(in_unit: str, out_unit: str): +def get_conversion(in_unit: str, out_unit: str) -> Callable[[float], float]: + """ + Utility function to get the conversion function between two units. + in_unit : str + The input unit + out_unit : str + The output unit + """ name = "convert_" + in_unit.lower().strip() + "_to_" + out_unit.lower().strip() if in_unit.lower().strip() == out_unit.lower().strip(): return lambda x: x @@ -142,6 +158,8 @@ def get_conversion(in_unit: str, out_unit: str): return CONVERSION_REGISTRY[name] +# Conversion definitions + # ev conversion Conversion("ev", "kcal/mol", lambda x: x * 23.0605) Conversion("ev", "hartree", lambda x: x * 0.0367493)