diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2c34132..1f04d36 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -53,5 +53,5 @@ jobs: - name: Run tests run: python -m pytest - - name: Test building the doc - run: mkdocs build + #- name: Test building the doc + # run: mkdocs build diff --git a/mkdocs.yml b/mkdocs.yml index c1be218..6db3178 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -12,8 +12,8 @@ docs_dir: "docs" nav: - Overview: index.md - Available Datasets: datasets.md - - Tutorials: - - Really hard example: tutorials/usage.ipynb + #- Tutorials: + # #- Really hard example: tutorials/usage.ipynb - API: - Datasets: API/available_datasets.md - Isolated Atoms Energies: API/isolated_atom_energies.md diff --git a/openqdc/cli.py b/openqdc/cli.py index ee86fb1..faae4ce 100644 --- a/openqdc/cli.py +++ b/openqdc/cli.py @@ -7,8 +7,11 @@ from typing_extensions import Annotated from openqdc.datasets import COMMON_MAP_POTENTIALS # noqa -from openqdc.datasets import AVAILABLE_DATASETS, AVAILABLE_POTENTIAL_DATASETS -from openqdc.raws.config_factory import DataConfigFactory, DataDownloader +from openqdc.datasets import ( + AVAILABLE_DATASETS, + AVAILABLE_INTERACTION_DATASETS, + AVAILABLE_POTENTIAL_DATASETS, +) app = typer.Typer(help="OpenQDC CLI") @@ -83,22 +86,49 @@ def datasets(): @app.command() -def fetch(datasets: List[str]): +def fetch( + datasets: List[str], + overwrite: Annotated[ + bool, + typer.Option( + help="Whether to overwrite or force the re-download of the files.", + ), + ] = False, + cache_dir: Annotated[ + Optional[str], + typer.Option( + help="Path to the cache. If not provided, the default cache directory (.cache/openqdc/) will be used.", + ), + ] = None, +): """ Download the raw datasets files from the main openQDC hub. - Special case: if the dataset is "all", all available datasets will be downloaded. - + overwrite: bool = False, + If True, the files will be re-downloaded and overwritten. + cache_dir: Optional[str] = None, + Path to the cache. If not provided, the default cache directory will be used. + Special case: if the dataset is "all", "potential", "interaction". + all: all available datasets will be downloaded. + potential: all the potential datasets will be downloaded + interaction: all the interaction datasets will be downloaded Example: openqdc fetch Spice """ - if datasets[0] == "all": - dataset_names = DataConfigFactory.available_datasets + if datasets[0].lower() == "all": + dataset_names = AVAILABLE_DATASETS + elif datasets[0].lower() == "potential": + dataset_names = AVAILABLE_POTENTIAL_DATASETS + elif datasets[0].lower() == "interaction": + dataset_names = AVAILABLE_INTERACTION_DATASETS else: dataset_names = datasets - for dataset_name in dataset_names: - dd = DataDownloader() - dd.from_name(dataset_name) + for dataset in list(map(lambda x: x.lower().replace("_", ""), dataset_names)): + if exist_dataset(dataset): + try: + AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite) + except Exception as e: + logger.error(f"Something unexpected happended while fetching {dataset}: {repr(e)}") @app.command() @@ -128,8 +158,6 @@ def preprocess( except Exception as e: logger.error(f"Error while preprocessing {dataset}. {e}. Did you fetch the dataset first?") raise e - else: - logger.warning(f"{dataset} not found.") if __name__ == "__main__": diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index 9648376..49a9b5c 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -82,6 +82,7 @@ class BaseDataset(DatasetPropertyMixIn): __distance_unit__ = "ang" __forces_unit__ = "hartree/ang" __average_nb_atoms__ = None + __links__ = {} def __init__( self, @@ -128,6 +129,7 @@ def __init__( set_cache_dir(cache_dir) # self._init_lambda_fn() self.data = None + self._original_unit = self.__energy_unit__ self.recompute_statistics = recompute_statistics self.regressor_kwargs = regressor_kwargs self.transform = transform @@ -145,6 +147,17 @@ def _init_lambda_fn(self): self._fn_distance = lambda x: x self._fn_forces = lambda x: x + @property + def config(self): + assert len(self.__links__) > 0, "No links provided for fetching" + return dict(dataset_name=self.__name__, links=self.__links__) + + @classmethod + def fetch(cls, cache_path: Optional[str] = None, overwrite: bool = False) -> None: + from openqdc.utils.download_api import DataDownloader + + DataDownloader(cache_path, overwrite).from_config(cls.no_init().config) + def _post_init( self, overwrite_local_cache: bool = False, @@ -256,6 +269,10 @@ def pkl_data_keys(self): def pkl_data_types(self): return {"name": str, "subset": str, "n_atoms": np.int32} + @property + def atom_energies(self): + return self._e0s_dispatcher + @property def data_types(self): return { @@ -287,7 +304,11 @@ def _set_units(self, en, ds): def _set_isolated_atom_energies(self): if self.__energy_methods__ is None: logger.error("No energy methods defined for this dataset.") - f = get_conversion("hartree", self.__energy_unit__) + if self.energy_type == "formation": + f = get_conversion("hartree", self.__energy_unit__) + else: + # regression are calculated on the original unit of the dataset + f = get_conversion(self._original_unit, self.__energy_unit__) self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix) def convert_energy(self, x): diff --git a/openqdc/datasets/energies.py b/openqdc/datasets/energies.py index 60af6d5..3a19233 100644 --- a/openqdc/datasets/energies.py +++ b/openqdc/datasets/energies.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass, field from os.path import join as p_join +from typing import Dict, Union import numpy as np from loguru import logger from openqdc.methods.enums import PotentialMethod +from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS from openqdc.utils.io import load_pkl, save_pkl from openqdc.utils.regressor import Regressor @@ -41,6 +44,56 @@ def dispatch_factory(data, **kwargs) -> "IsolatedEnergyInterface": return NullEnergy(data, **kwargs) +@dataclass(frozen=False, eq=True) +class AtomSpecies: + """ + Structure that defines a tuple of chemical specie and charge + and provide hash and automatic conversion from atom number to + checmical symbol + """ + + symbol: Union[str, int] + charge: int = 0 + + def __post_init__(self): + if not isinstance(self.symbol, str): + self.symbol = ATOM_SYMBOLS[self.symbol] + self.number = ATOMIC_NUMBERS[self.symbol] + + def __hash__(self): + return hash((self.symbol, self.charge)) + + def __eq__(self, other): + if not isinstance(other, AtomSpecies): + symbol, charge = other[0], other[1] + other = AtomSpecies(symbol=symbol, charge=charge) + return (self.number, self.charge) == (other.number, other.charge) + + +@dataclass +class AtomEnergy: + """ + Datastructure to store isolated atom energies + and the std deviation associated to the value. + By default the std will be 1 if no value was calculated + or not available (formation energy case) + """ + + mean: np.array + std: np.array = field(default_factory=lambda: np.array([1], dtype=np.float32)) + + def __post_init__(self): + if not isinstance(self.mean, np.ndarray): + self.mean = np.array([self.mean], dtype=np.float32) + + def append(self, other: "AtomEnergy"): + """ + Append the mean and std of another atom energy + """ + self.mean = np.append(self.mean, other.mean) + self.std = np.append(self.std, other.std) + + class AtomEnergies: """ Manager class for interface with the isolated atom energies classes @@ -71,6 +124,42 @@ def e0s_matrix(self) -> np.ndarray: """ return self.factory.e0_matrix + @property + def e0s_dict(self) -> Dict[AtomSpecies, AtomEnergy]: + """ + Return the isolated atom energies dictionary + """ + return self.factory.e0_dict + + def __str__(self): + return f"Atoms: { list(set(map(lambda x : x.symbol, self.e0s_dict.keys())))}" + + def __repr__(self): + return str(self) + + def __getitem__(self, item: AtomSpecies) -> AtomEnergy: + """ + Retrieve a key from the isolated atom dictionary. + Item can be written as tuple(Symbol, charge), + tuple(Chemical number, charge). If no charge is passed, + it will be automatically set to 0. + Examples: + AtomEnergies[6], AtomEnergies[6,1], + AtomEnergies["C",1], AtomEnergies[(6,1)] + AtomEnergies[("C,1)] + """ + try: + atom, charge = item[0], item[1] + except TypeError: + atom = item + charge = 0 + except IndexError: + atom = item[0] + charge = 0 + if not isinstance(atom, str): + atom = ATOM_SYMBOLS[atom] + return self.e0s_dict[(atom, charge)] + class IsolatedEnergyInterface(ABC): """ @@ -78,8 +167,6 @@ class IsolatedEnergyInterface(ABC): different implementation of an isolated atom energy value """ - _e0_matrixs = [] - def __init__(self, data, **kwargs): """ Parameters @@ -93,7 +180,8 @@ def __init__(self, data, **kwargs): selected energy class. Mostly used for regression to pass the regressor_kwargs. """ - + self._e0_matrixs = [] + self._e0_dict = None self.kwargs = kwargs self.data = data self._post_init() @@ -120,27 +208,61 @@ def e0_matrix(self) -> np.ndarray: """ return np.array(self._e0_matrixs) + @property + def e0_dict(self) -> Dict: + """ + Return the isolated atom energies dict + """ + + return self._e0s_dict + def __str__(self) -> str: return self.__class__.__name__.lower() -class NullEnergy(IsolatedEnergyInterface): +class PhysicalEnergy(IsolatedEnergyInterface): """ - Class that returns a null (zeros) matrix for the isolated atom energies in case - of no energies are available. + Class that returns a physical (SE,DFT,etc) isolated atom energies. """ + def _assembly_e0_dict(self): + datum = {} + for method in self.data.__energy_methods__: + for key, values in method.atom_energies_dict.items(): + atm = AtomSpecies(*key) + ens = AtomEnergy(values) + if atm not in datum: + datum[atm] = ens + else: + datum[atm].append(ens) + self._e0s_dict = datum + def _post_init(self): - self._e0_matrixs = [PotentialMethod.NONE.atom_energies_matrix for _ in range(len(self.data.energy_methods))] + self._e0_matrixs = [energy_method.atom_energies_matrix for energy_method in self.data.__energy_methods__] + self._assembly_e0_dict() -class PhysicalEnergy(IsolatedEnergyInterface): +class NullEnergy(IsolatedEnergyInterface): """ - Class that returns a physical (SE,DFT,etc) isolated atom energies. + Class that returns a null (zeros) matrix for the isolated atom energies in case + of no energies are available. """ + def _assembly_e0_dict(self): + datum = {} + for _ in self.data.__energy_methods__: + for key, values in PotentialMethod.NONE.atom_energies_dict.items(): + atm = AtomSpecies(*key) + ens = AtomEnergy(values) + if atm not in datum: + datum[atm] = ens + else: + datum[atm].append(ens) + self._e0s_dict = datum + def _post_init(self): - self._e0_matrixs = [energy_method.atom_energies_matrix for energy_method in self.data.__energy_methods__] + self._e0_matrixs = [PotentialMethod.NONE.atom_energies_matrix for _ in range(len(self.data.energy_methods))] + self._assembly_e0_dict() class RegressionEnergy(IsolatedEnergyInterface): @@ -175,7 +297,9 @@ def _set_lin_atom_species_dict(self, E0s, covs) -> None: """ atomic_energies_dict = {} for i, z in enumerate(self.regressor.numbers): - atomic_energies_dict[z] = E0s[i] + for charge in range(-10, 11): + atomic_energies_dict[AtomSpecies(z, charge)] = AtomEnergy(E0s[i], 1 if covs is None else covs[i]) + # atomic_energies_dict[z] = E0s[i] self._e0s_dict = atomic_energies_dict self.save_e0s() @@ -187,7 +311,9 @@ def _set_linear_e0s(self) -> None: new_e0s = [np.zeros((max(self.data.numbers) + 1, MAX_CHARGE_NUMBER)) for _ in range(len(self))] for z, e0 in self._e0s_dict.items(): for i in range(len(self)): - new_e0s[i][z, :] = e0[i] + # new_e0s[i][z, :] = e0[i] + new_e0s[i][z.number, z.charge] = e0.mean[i] + # for atom_sp, values in self._e0_matrixs = new_e0s def save_e0s(self) -> None: diff --git a/openqdc/datasets/interaction/des.py b/openqdc/datasets/interaction/des.py index 0a7cc33..d90be07 100644 --- a/openqdc/datasets/interaction/des.py +++ b/openqdc/datasets/interaction/des.py @@ -147,6 +147,9 @@ class DES370K(BaseInteractionDataset, IDES): "sapt_exdisp_ss", "sapt_delta_HF", ] + __links__ = { + "DES370K.zip": "https://zenodo.org/record/5676266/files/DES370K.zip", + } @property def csv_path(self): @@ -232,6 +235,9 @@ class DES5M(DES370K): "sapt_exdisp_ss", "sapt_delta_HF", ] + __links__ = { + "DES5M.zip": "https://zenodo.org/records/5706002/files/DESS5M.zip?download=1", + } class DESS66(DES370K): @@ -252,6 +258,7 @@ class DESS66(DES370K): __name__ = "des_s66" __filename__ = "DESS66.csv" + __links__ = {"DESS66.zip": "https://zenodo.org/records/5676284/files/DESS66.zip?download=1"} def _create_subsets(self, **kwargs): return kwargs["row"]["system_name"] @@ -276,3 +283,4 @@ class DESS66x8(DESS66): __name__ = "des_s66x8" __filename__ = "DESS66x8.csv" + __links__ = {"DESS66x8.zip": "https://zenodo.org/records/5676284/files/DESS66x8.zip?download=1"} diff --git a/openqdc/datasets/interaction/l7.py b/openqdc/datasets/interaction/l7.py index 3f77b44..75a63cd 100644 --- a/openqdc/datasets/interaction/l7.py +++ b/openqdc/datasets/interaction/l7.py @@ -29,6 +29,10 @@ class L7(YamlDataset): InteractionMethod.LNO_CCSDT, # "LNO-CCSD(T)", InteractionMethod.FN_DMC, # "FN-DMC", ] + __links__ = { + "l7.yaml": "http://cuby4.molecular.cz/download_datasets/l7.yaml", + "geometries.tar.gz": "http://cuby4.molecular.cz/download_geometries/L7.tar", + } def _process_name(self, item): return item.geometry.split(":")[1] diff --git a/openqdc/datasets/interaction/metcalf.py b/openqdc/datasets/interaction/metcalf.py index 2d70746..faf5324 100644 --- a/openqdc/datasets/interaction/metcalf.py +++ b/openqdc/datasets/interaction/metcalf.py @@ -10,8 +10,8 @@ from openqdc.datasets.interaction.base import BaseInteractionDataset from openqdc.methods import InteractionMethod, InterEnergyType -from openqdc.raws.config_factory import decompress_tar_gz from openqdc.utils.constants import ATOM_TABLE +from openqdc.utils.download_api import decompress_tar_gz EXPECTED_TAR_FILES = { "train": [ @@ -125,6 +125,7 @@ class Metcalf(BaseInteractionDataset): "induction energy", "dispersion energy", ] + __links__ = {"model-data.tar.gz": "https://zenodo.org/records/10934211/files/model-data.tar?download=1"} def read_raw_entries(self) -> List[Dict]: # extract in folders diff --git a/openqdc/datasets/interaction/splinter.py b/openqdc/datasets/interaction/splinter.py index a793624..bda1012 100644 --- a/openqdc/datasets/interaction/splinter.py +++ b/openqdc/datasets/interaction/splinter.py @@ -92,6 +92,24 @@ class Splinter(BaseInteractionDataset): InterEnergyType.DISP, ] energy_target_names = [] + __links__ = { + "dimerpairs.0.tar.gz": "https://figshare.com/ndownloader/files/39449167", + "dimerpairs.1.tar.gz": "https://figshare.com/ndownloader/files/40271983", + "dimerpairs.2.tar.gz": "https://figshare.com/ndownloader/files/40271989", + "dimerpairs.3.tar.gz": "https://figshare.com/ndownloader/files/40272001", + "dimerpairs.4.tar.gz": "https://figshare.com/ndownloader/files/40272022", + "dimerpairs.5.tar.gz": "https://figshare.com/ndownloader/files/40552931", + "dimerpairs.6.tar.gz": "https://figshare.com/ndownloader/files/40272040", + "dimerpairs.7.tar.gz": "https://figshare.com/ndownloader/files/40272052", + "dimerpairs.8.tar.gz": "https://figshare.com/ndownloader/files/40272061", + "dimerpairs.9.tar.gz": "https://figshare.com/ndownloader/files/40272064", + "dimerpairs_nonstandard.tar.gz": "https://figshare.com/ndownloader/files/40272067", + "lig_interaction_sites.sdf": "https://figshare.com/ndownloader/files/40272070", + "lig_monomers.sdf": "https://figshare.com/ndownloader/files/40272073", + "prot_interaction_sites.sdf": "https://figshare.com/ndownloader/files/40272076", + "prot_monomers.sdf": "https://figshare.com/ndownloader/files/40272079", + "merge_monomers.py": "https://figshare.com/ndownloader/files/41807682", + } def read_raw_entries(self) -> List[Dict]: logger.info(f"Reading Splinter interaction data from {self.root}") diff --git a/openqdc/datasets/interaction/x40.py b/openqdc/datasets/interaction/x40.py index 32a3cbf..64da5d8 100644 --- a/openqdc/datasets/interaction/x40.py +++ b/openqdc/datasets/interaction/x40.py @@ -28,6 +28,10 @@ class X40(YamlDataset): InteractionMethod.DCCSDT_HA_TZ, # "dCCSD(T)/haTZ", InteractionMethod.MP2_5_CBS_ADZ, # "MP2.5/CBS(aDZ)", ] + __links__ = { + "x40.yaml": "http://cuby4.molecular.cz/download_datasets/x40.yaml", + "geometries.tar.gz": "http://cuby4.molecular.cz/download_geometries/X40.tar", + } def _process_name(self, item): return item.shortname diff --git a/openqdc/datasets/io.py b/openqdc/datasets/io.py index bf90ea5..cd8bfdb 100644 --- a/openqdc/datasets/io.py +++ b/openqdc/datasets/io.py @@ -47,6 +47,7 @@ def __init__( self.recompute_statistics = True self.refit_e0s = True self.energy_type = energy_type + self._original_unit = energy_unit self.__energy_unit__ = energy_unit self.__distance_unit__ = distance_unit self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory] diff --git a/openqdc/datasets/potential/ani.py b/openqdc/datasets/potential/ani.py index dcf9fcd..70fc882 100644 --- a/openqdc/datasets/potential/ani.py +++ b/openqdc/datasets/potential/ani.py @@ -37,11 +37,17 @@ class ANI1(BaseDataset): __energy_unit__ = "hartree" __distance_unit__ = "bohr" __forces_unit__ = "hartree/bohr" + __links__ = {"ani1.hdf5.gz": "https://zenodo.org/record/3585840/files/214.hdf5.gz"} @property def root(self): return p_join(get_local_cache(), "ani") + @property + def config(self): + assert len(self.__links__) > 0, "No links provided for fetching" + return dict(dataset_name="ani", links=self.__links__) + def __smiles_converter__(self, x): """util function to convert string to smiles: useful if the smiles is encoded in a different format than its display format @@ -95,6 +101,7 @@ class ANI1CCX(ANI1): "TPNO-CCSD(T):cc-pVDZ Correlation Energy", ] force_target_names = [] + __links__ = {"ani1x.hdf5.gz": "https://zenodo.org/record/4081694/files/292.hdf5.gz"} def __smiles_converter__(self, x): """util function to convert string to smiles: useful if the smiles is @@ -152,6 +159,7 @@ class ANI1X(ANI1): ] __force_mask__ = [False, False, False, False, False, False, True, True] + __links__ = {"ani1ccx.hdf5.gz": "https://zenodo.org/record/4081692/files/293.hdf5.gz"} def convert_forces(self, x): return super().convert_forces(x) * 0.529177249 # correct the Dataset error diff --git a/openqdc/datasets/potential/gdml.py b/openqdc/datasets/potential/gdml.py index 92ccb3e..d58be18 100644 --- a/openqdc/datasets/potential/gdml.py +++ b/openqdc/datasets/potential/gdml.py @@ -56,6 +56,14 @@ class GDML(BaseDataset): __energy_unit__ = "kcal/mol" __distance_unit__ = "bohr" __forces_unit__ = "kcal/mol/bohr" + __links__ = { + "gdb7_9.hdf5.gz": "https://zenodo.org/record/3588361/files/208.hdf5.gz", + "gdb10_13.hdf5.gz": "https://zenodo.org/record/3588364/files/209.hdf5.gz", + "drugbank.hdf5.gz": "https://zenodo.org/record/3588361/files/207.hdf5.gz", + "tripeptides.hdf5.gz": "https://zenodo.org/record/3588368/files/211.hdf5.gz", + "ani_md.hdf5.gz": "https://zenodo.org/record/3588341/files/205.hdf5.gz", + "s66x8.hdf5.gz": "https://zenodo.org/record/3588367/files/210.hdf5.gz", + } def read_raw_entries(self): raw_path = p_join(self.root, "gdml.h5.gz") diff --git a/openqdc/datasets/potential/geom.py b/openqdc/datasets/potential/geom.py index 0b20a7c..d07a3d9 100644 --- a/openqdc/datasets/potential/geom.py +++ b/openqdc/datasets/potential/geom.py @@ -87,6 +87,7 @@ class GEOM(BaseDataset): force_target_names = [] partitions = ["qm9", "drugs"] + __links__ = {"rdkit_folder.tar.gz": "https://dataverse.harvard.edu/api/access/datafile/4327252"} def _read_raw_(self, partition): raw_path = p_join(self.root, "rdkit_folder") diff --git a/openqdc/datasets/potential/iso_17.py b/openqdc/datasets/potential/iso_17.py index 7fd7be9..9263015 100644 --- a/openqdc/datasets/potential/iso_17.py +++ b/openqdc/datasets/potential/iso_17.py @@ -42,6 +42,7 @@ class ISO17(BaseDataset): __energy_unit__ = "ev" __distance_unit__ = "bohr" # bohr __forces_unit__ = "ev/bohr" + __links__ = {"iso_17.hdf5.gz": "https://zenodo.org/record/3585907/files/216.hdf5.gz"} def __smiles_converter__(self, x): """util function to convert string to smiles: useful if the smiles is diff --git a/openqdc/datasets/potential/md22.py b/openqdc/datasets/potential/md22.py index 6697dd0..b997642 100644 --- a/openqdc/datasets/potential/md22.py +++ b/openqdc/datasets/potential/md22.py @@ -41,6 +41,18 @@ def create_path(filename, root): class MD22(RevMD17): __name__ = "md22" + __links__ = { + f"{x}.npz": f"http://www.quantum-machine.org/gdml/repo/datasets/md22_{x}.npz" + for x in [ + "Ac-Ala3-NHMe", + "DHA", + "stachyose", + "AT-AT", + "AT-AT-CG-CG", + "double-walled_nanotube", + "buckyball-catcher", + ] + } def read_raw_entries(self): entries_list = [] diff --git a/openqdc/datasets/potential/molecule3d.py b/openqdc/datasets/potential/molecule3d.py index e1e10fc..4fc28c7 100644 --- a/openqdc/datasets/potential/molecule3d.py +++ b/openqdc/datasets/potential/molecule3d.py @@ -88,6 +88,7 @@ class Molecule3D(BaseDataset): __energy_unit__ = "ev" # CALCULATED __distance_unit__ = "ang" __forces_unit__ = "ev/ang" + __links__ = {"molecule3d.zip": "https://drive.google.com/uc?id=1C_KRf8mX-gxny7kL9ACNCEV4ceu_fUGy"} energy_target_names = ["b3lyp/6-31g*.energy"] diff --git a/openqdc/datasets/potential/multixcqm9.py b/openqdc/datasets/potential/multixcqm9.py index 2bf4906..83263d7 100644 --- a/openqdc/datasets/potential/multixcqm9.py +++ b/openqdc/datasets/potential/multixcqm9.py @@ -522,6 +522,14 @@ class MultixcQM9(BaseDataset): __energy_unit__ = "ev" # to fix __distance_unit__ = "ang" # to fix __forces_unit__ = "ev/ang" # to fix + __links__ = { + "xyz.zip": "https://data.dtu.dk/ndownloader/files/35143624", + "xtb.zip": "https://data.dtu.dk/ndownloader/files/42444300", + "dzp.zip": "https://data.dtu.dk/ndownloader/files/42443925", + "tzp.zip": "https://data.dtu.dk/ndownloader/files/42444129", + "sz.zip": "https://data.dtu.dk/ndownloader/files/42441345", + "failed_indices.dat": "https://data.dtu.dk/ndownloader/files/37337677", + } def _read_molecules_energies(self): d = {"DZP": None, "TZP": None, "SZ": None, "XTB": None} diff --git a/openqdc/datasets/potential/nabladft.py b/openqdc/datasets/potential/nabladft.py index c10f108..4700ade 100644 --- a/openqdc/datasets/potential/nabladft.py +++ b/openqdc/datasets/potential/nabladft.py @@ -74,6 +74,7 @@ class NablaDFT(BaseDataset): __energy_unit__ = "hartree" __distance_unit__ = "bohr" __forces_unit__ = "hartree/bohr" + __links__ = {"nabladft.db": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_full.db"} @requires_package("nablaDFT") def read_raw_entries(self): diff --git a/openqdc/datasets/potential/orbnet_denali.py b/openqdc/datasets/potential/orbnet_denali.py index fb7476d..6a7c3f4 100644 --- a/openqdc/datasets/potential/orbnet_denali.py +++ b/openqdc/datasets/potential/orbnet_denali.py @@ -61,6 +61,10 @@ class OrbnetDenali(BaseDataset): __energy_unit__ = "hartree" __distance_unit__ = "ang" __forces_unit__ = "hartree/ang" + __links__ = { + "orbnet_denali.tar.gz": "https://figshare.com/ndownloader/files/28672287", + "orbnet_denali_targets.tar.gz": "https://figshare.com/ndownloader/files/28672248", + } def read_raw_entries(self): label_path = p_join(self.root, "denali_labels.csv") diff --git a/openqdc/datasets/potential/qm7x.py b/openqdc/datasets/potential/qm7x.py index 5357067..92689c5 100644 --- a/openqdc/datasets/potential/qm7x.py +++ b/openqdc/datasets/potential/qm7x.py @@ -66,6 +66,7 @@ class QM7X(BaseDataset): __energy_unit__ = "ev" __distance_unit__ = "ang" __forces_unit__ = "ev/ang" + __links__ = {f"{i}000.xz": f"https://zenodo.org/record/4288677/files/{i}000.xz" for i in range(1, 9)} def read_raw_entries(self): samples = [] diff --git a/openqdc/datasets/potential/qmugs.py b/openqdc/datasets/potential/qmugs.py index ceaeb11..7dd205e 100644 --- a/openqdc/datasets/potential/qmugs.py +++ b/openqdc/datasets/potential/qmugs.py @@ -57,6 +57,10 @@ class QMugs(BaseDataset): __energy_unit__ = "hartree" __distance_unit__ = "ang" __forces_unit__ = "hartree/ang" + __links__ = { + "summary.csv": "https://libdrive.ethz.ch/index.php/s/X5vOBNSITAG5vzM/download?path=%2F&files=summary.csv", + "structures.tar.gz": "https://libdrive.ethz.ch/index.php/s/X5vOBNSITAG5vzM/download?path=%2F&files=structures.tar.gz", # noqa + } energy_target_names = [ "GFN2:TOTAL_ENERGY", diff --git a/openqdc/datasets/potential/revmd17.py b/openqdc/datasets/potential/revmd17.py index aeb7865..613ce91 100644 --- a/openqdc/datasets/potential/revmd17.py +++ b/openqdc/datasets/potential/revmd17.py @@ -4,7 +4,7 @@ from openqdc.datasets.base import BaseDataset from openqdc.methods import PotentialMethod -from openqdc.raws.config_factory import decompress_tar_gz +from openqdc.utils.download_api import decompress_tar_gz trajectories = { "rmd17_aspirin": "CC(=O)OC1=CC=CC=C1C(=O)O", @@ -92,6 +92,7 @@ class RevMD17(BaseDataset): force_target_names = [ "PBE-TS Gradient", ] + __links__ = {"revmd17.zip": "https://figshare.com/ndownloader/articles/12672038/versions/3"} __energy_unit__ = "kcal/mol" __distance_unit__ = "ang" diff --git a/openqdc/datasets/potential/sn2_rxn.py b/openqdc/datasets/potential/sn2_rxn.py index 8e7bba9..0f58069 100644 --- a/openqdc/datasets/potential/sn2_rxn.py +++ b/openqdc/datasets/potential/sn2_rxn.py @@ -32,6 +32,7 @@ class SN2RXN(BaseDataset): __energy_unit__ = "ev" __distance_unit__ = "bohr" __forces_unit__ = "ev/bohr" + __links__ = {"sn2_rxn.hdf5.gz": "https://zenodo.org/records/2605341/files/sn2_reactions.npz"} energy_target_names = [ # TODO: We need to revalidate this to make sure that is not atomization energies. diff --git a/openqdc/datasets/potential/solvated_peptides.py b/openqdc/datasets/potential/solvated_peptides.py index 9e4677d..2637ca4 100644 --- a/openqdc/datasets/potential/solvated_peptides.py +++ b/openqdc/datasets/potential/solvated_peptides.py @@ -44,6 +44,7 @@ class SolvatedPeptides(BaseDataset): __energy_unit__ = "hartree" __distance_unit__ = "bohr" __forces_unit__ = "hartree/bohr" + __links__ = {"solvated_peptides.hdf5.gz": "https://zenodo.org/record/3585804/files/213.hdf5.gz"} def __smiles_converter__(self, x): """util function to convert string to smiles: useful if the smiles is @@ -52,7 +53,7 @@ def __smiles_converter__(self, x): return "_".join(x.decode("ascii").split("_")[:-1]) def read_raw_entries(self): - raw_path = p_join(self.root, "solvated_peptides.h5") + raw_path = p_join(self.root, "solvated_peptides.h5.gz") samples = read_qc_archive_h5(raw_path, "solvated_peptides", self.energy_target_names, self.force_target_names) return samples diff --git a/openqdc/datasets/potential/spice.py b/openqdc/datasets/potential/spice.py index c55ed78..62b550b 100644 --- a/openqdc/datasets/potential/spice.py +++ b/openqdc/datasets/potential/spice.py @@ -80,6 +80,7 @@ class Spice(BaseDataset): "SPICE PubChem Set 6 Single Points Dataset v1.2": "PubChem", "SPICE Ion Pairs Single Points Dataset v1.1": "Ion Pairs", } + __links__ = {"SPICE-1.1.4.hdf5": "https://zenodo.org/record/8222043/files/SPICE-1.1.4.hdf5"} def convert_forces(self, x): return (-1.0) * super().convert_forces(x) @@ -135,6 +136,7 @@ class SpiceV2(Spice): "SPICE PubChem Boron Silicon v1.0": "PubChem Boron Silicon", "SPICE Ion Pairs Single Points Dataset v1.2": "Ion Pairs", } + __links__ = {"spice-2.0.0.hdf5": "https://zenodo.org/records/10835749/files/SPICE-2.0.0.hdf5?download=1"} def read_raw_entries(self): raw_path = p_join(self.root, "spice-2.0.0.hdf5") diff --git a/openqdc/datasets/potential/tmqm.py b/openqdc/datasets/potential/tmqm.py index 4d5b856..1da6901 100644 --- a/openqdc/datasets/potential/tmqm.py +++ b/openqdc/datasets/potential/tmqm.py @@ -72,6 +72,10 @@ class TMQM(BaseDataset): __energy_unit__ = "hartree" __distance_unit__ = "ang" __forces_unit__ = "hartree/ang" + __links__ = { + x: f"https://raw.githubusercontent.com/bbskjelstad/tmqm/master/data/{x}" + for x in ["tmQM_X1.xyz.gz", "tmQM_X2.xyz.gz", "tmQM_y.csv", "Benchmark2_TPSSh_Opt.xyz"] + } def read_raw_entries(self): df = pd.read_csv(p_join(self.root, "tmQM_y.csv"), sep=";", usecols=["CSD_code", "Electronic_E"]) diff --git a/openqdc/datasets/potential/transition1x.py b/openqdc/datasets/potential/transition1x.py index 8153304..8b5b4bc 100644 --- a/openqdc/datasets/potential/transition1x.py +++ b/openqdc/datasets/potential/transition1x.py @@ -73,6 +73,7 @@ class Transition1X(BaseDataset): __energy_unit__ = "ev" __distance_unit__ = "ang" __forces_unit__ = "ev/ang" + __links__ = {"Transition1x.h5": "https://figshare.com/ndownloader/files/36035789"} def read_raw_entries(self): raw_path = p_join(self.root, "Transition1x.h5") diff --git a/openqdc/datasets/potential/waterclusters3_30.py b/openqdc/datasets/potential/waterclusters3_30.py index 7f3086f..e473353 100644 --- a/openqdc/datasets/potential/waterclusters3_30.py +++ b/openqdc/datasets/potential/waterclusters3_30.py @@ -76,6 +76,7 @@ class WaterClusters(BaseDataset): __energy_methods__ = [PotentialMethod.TTM2_1_F] # "ttm2.1-f" energy_target_names = ["TTM2.1-F Potential"] + __links__ = {"W3-W30_all_geoms_TTM2.1-F.zip": "https://drive.google.com/uc?id=18Y7OiZXSCTsHrQ83GCc4fyE_abbL6E_n"} def read_raw_entries(self): samples = [] diff --git a/openqdc/datasets/statistics.py b/openqdc/datasets/statistics.py index 2122271..d471387 100644 --- a/openqdc/datasets/statistics.py +++ b/openqdc/datasets/statistics.py @@ -208,6 +208,7 @@ def attempt_load(self) -> bool: """ try: self.result = load_pkl(self.preprocess_path) + logger.info(f"Statistics for {str(self)} loaded successfully") return True except FileNotFoundError: logger.warning(f"Statistics for {str(self)} not found. Computing...") diff --git a/openqdc/methods/enums.py b/openqdc/methods/enums.py index 50a7f1d..3cf5104 100644 --- a/openqdc/methods/enums.py +++ b/openqdc/methods/enums.py @@ -1,8 +1,10 @@ from enum import Enum from loguru import logger +from numpy import array, float32 from openqdc.methods.atom_energies import atom_energy_collection, to_e_matrix +from openqdc.utils.constants import ATOM_SYMBOLS class StrEnum(str, Enum): @@ -472,6 +474,13 @@ class PotentialMethod(QmMethod): # SPLIT FOR INTERACTIO ENERGIES AND FIX MD1 XLYP_TZP = Functional.XLYP, BasisSet.TZP NONE = Functional.NONE, BasisSet.NONE + def _build_default_dict(self): + e0_dict = {} + for SYMBOL in ATOM_SYMBOLS: + for CHARGE in range(-10, 11): + e0_dict[(SYMBOL, CHARGE)] = array([0], dtype=float32) + return e0_dict + @property def atom_energies_dict(self): """Get the atomization energy dictionary""" @@ -483,7 +492,7 @@ def atom_energies_dict(self): raise except: # noqa logger.info(f"No available atomization energy for the QM method {key}. All values are set to 0.") - + energies = self._build_default_dict() return energies diff --git a/openqdc/raws/__init__.py b/openqdc/raws/__init__.py deleted file mode 100644 index 5bda993..0000000 --- a/openqdc/raws/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .config_factory import DataConfigFactory, DataDownloader diff --git a/openqdc/raws/config_factory.py b/openqdc/raws/config_factory.py deleted file mode 100644 index aa3e5d8..0000000 --- a/openqdc/raws/config_factory.py +++ /dev/null @@ -1,420 +0,0 @@ -import gzip -import os -import shutil -import socket -import tarfile -import urllib.error -import urllib.request -import zipfile - -import fsspec -import gdown -import requests -import tqdm -from loguru import logger -from sklearn.utils import Bunch - -from openqdc.utils.io import get_local_cache - - -def download_url(url, local_filename): - """ - Download a file from a url to a local file. - Parameters - ---------- - url : str - URL to download from. - local_filename : str - Local path for destination. - """ - logger.info(f"Url: {url} File: {local_filename}") - if "drive.google.com" in url: - gdown.download(url, local_filename, quiet=False) - elif "raw.github" in url: - r = requests.get(url, allow_redirects=True) - with open(local_filename, "wb") as f: - f.write(r.content) - else: - r = requests.get(url, stream=True) - with fsspec.open(local_filename, "wb") as f: - for chunk in tqdm.tqdm(r.iter_content(chunk_size=16384)): - if chunk: - f.write(chunk) - - -def decompress_tar_gz(local_filename): - """ - Decompress a tar.gz file. - Parameters - ---------- - local_filename : str - Path to local file to decompress. - """ - parent = os.path.dirname(local_filename) - with tarfile.open(local_filename) as tar: - logger.info(f"Verifying archive extraction states: {local_filename}") - all_names = tar.getnames() - all_extracted = all([os.path.exists(os.path.join(parent, x)) for x in all_names]) - if not all_extracted: - logger.info(f"Extracting archive: {local_filename}") - tar.extractall(path=parent) - else: - logger.info(f"Archive already extracted: {local_filename}") - - -def decompress_zip(local_filename): - """ - Decompress a zip file. - Parameters - ---------- - local_filename : str - Path to local file to decompress. - """ - parent = os.path.dirname(local_filename) - - logger.info(f"Verifying archive extraction states: {local_filename}") - with zipfile.ZipFile(local_filename, "r") as zip_ref: - all_names = zip_ref.namelist() - all_extracted = all([os.path.exists(os.path.join(parent, x)) for x in all_names]) - if not all_extracted: - logger.info(f"Extracting archive: {local_filename}") - zip_ref.extractall(parent) - else: - logger.info(f"Archive already extracted: {local_filename}") - - -def decompress_gz(local_filename): - """ - Decompress a gz file. - Parameters - ---------- - local_filename : str - Path to local file to decompress. - """ - logger.info(f"Verifying archive extraction states: {local_filename}") - out_filename = local_filename.replace(".gz", "") - if out_filename.endswith("hdf5"): - out_filename = local_filename.replace("hdf5", "h5") - - all_extracted = os.path.exists(out_filename) - if not all_extracted: - logger.info(f"Extracting archive: {local_filename}") - with gzip.open(local_filename, "rb") as f_in, open(out_filename, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - else: - logger.info(f"Archive already extracted: {local_filename}") - - -def fetch_file(url, local_filename, overwrite=False): - """ - Download a file from a url to a local file. Useful for big files. - Parameters - ---------- - url : str - URL to download from. - local_filename : str - Local file to save to. - overwrite : bool - Whether to overwrite existing files. - Returns - ------- - local_filename : str - Local file. - """ - try: - if os.path.exists(local_filename) and not overwrite: - logger.info("File already exists, skipping download") - else: - download_url(url, local_filename) - - # decompress archive if necessary - parent = os.path.dirname(local_filename) - if local_filename.endswith("tar.gz"): - decompress_tar_gz(local_filename) - - elif local_filename.endswith("zip"): - decompress_zip(local_filename) - - elif local_filename.endswith(".gz"): - decompress_gz(local_filename) - - elif local_filename.endswith("xz"): - logger.info(f"Extracting archive: {local_filename}") - os.system(f"cd {parent} && xz -d *.xz") - - else: - pass - - except (socket.gaierror, urllib.error.URLError) as err: - raise ConnectionError("Could not download {} due to {}".format(url, err)) - - return local_filename - - -class DataConfigFactory: - ani = dict( - dataset_name="ani", - links={ - "ani1.hdf5.gz": "https://zenodo.org/record/3585840/files/214.hdf5.gz", - "ani1x.hdf5.gz": "https://zenodo.org/record/4081694/files/292.hdf5.gz", - "ani1ccx.hdf5.gz": "https://zenodo.org/record/4081692/files/293.hdf5.gz", - }, - ) - - comp6 = dict( - dataset_name="comp6", - links={ - "gdb7_9.hdf5.gz": "https://zenodo.org/record/3588361/files/208.hdf5.gz", - "gdb10_13.hdf5.gz": "https://zenodo.org/record/3588364/files/209.hdf5.gz", - "drugbank.hdf5.gz": "https://zenodo.org/record/3588361/files/207.hdf5.gz", - "tripeptides.hdf5.gz": "https://zenodo.org/record/3588368/files/211.hdf5.gz", - "ani_md.hdf5.gz": "https://zenodo.org/record/3588341/files/205.hdf5.gz", - "s66x8.hdf5.gz": "https://zenodo.org/record/3588367/files/210.hdf5.gz", - }, - ) - - gdml = dict( - dataset_name="gdml", - links={"gdml.hdf5.gz": "https://zenodo.org/record/3585908/files/219.hdf5.gz"}, - ) - - solvated_peptides = dict( - dataset_name="solvated_peptides", - links={"solvated_peptides.hdf5.gz": "https://zenodo.org/record/3585804/files/213.hdf5.gz"}, - ) - - iso_17 = dict( - dataset_name="iso_17", - links={"iso_17.hdf5.gz": "https://zenodo.org/record/3585907/files/216.hdf5.gz"}, - ) - - sn2_rxn = dict( - dataset_name="sn2_rxn", - links={"sn2_rxn.hdf5.gz": "https://zenodo.org/records/2605341/files/sn2_reactions.npz"}, - ) - - # FROM: https://sites.uw.edu/wdbase/database-of-water-clusters/ - waterclusters3_30 = dict( - dataset_name="waterclusters3_30", - links={"W3-W30_all_geoms_TTM2.1-F.zip": "https://drive.google.com/uc?id=18Y7OiZXSCTsHrQ83GCc4fyE_abbL6E_n"}, - ) - - geom = dict( - dataset_name="geom", - links={"rdkit_folder.tar.gz": "https://dataverse.harvard.edu/api/access/datafile/4327252"}, - ) - - l7 = dict( - dataset_name="l7", - links={ - "l7.yaml": "http://cuby4.molecular.cz/download_datasets/l7.yaml", - "geometries.tar.gz": "http://cuby4.molecular.cz/download_geometries/L7.tar", - }, - ) - - molecule3d = dict( - dataset_name="molecule3d", - links={"molecule3d.zip": "https://drive.google.com/uc?id=1C_KRf8mX-gxny7kL9ACNCEV4ceu_fUGy"}, - ) - - orbnet_denali = dict( - dataset_name="orbnet_denali", - links={ - "orbnet_denali.tar.gz": "https://figshare.com/ndownloader/files/28672287", - "orbnet_denali_targets.tar.gz": "https://figshare.com/ndownloader/files/28672248", - }, - ) - - qm7x = dict( - dataset_name="qm7x", - links={f"{i}000.xz": f"https://zenodo.org/record/4288677/files/{i}000.xz" for i in range(1, 9)}, - ) - - qmugs = dict( - dataset_name="qmugs", - links={ - "summary.csv": "https://libdrive.ethz.ch/index.php/s/X5vOBNSITAG5vzM/download?path=%2F&files=summary.csv", - "structures.tar.gz": "https://libdrive.ethz.ch/index.php/s/X5vOBNSITAG5vzM/download?path=%2F&files=structures.tar.gz", - }, - ) - - spice = dict( - dataset_name="spice", - links={"SPICE-1.1.4.hdf5": "https://zenodo.org/record/8222043/files/SPICE-1.1.4.hdf5"}, - ) - spicev2 = dict( - dataset_name="spicev2", - links={"spice-2.0.0.hdf5": "https://zenodo.org/records/10835749/files/SPICE-2.0.0.hdf5?download=1"}, - ) - - splinter = dict( - dataset_name="splinter", - links={ - "dimerpairs.0.tar.gz": "https://figshare.com/ndownloader/files/39449167", - "dimerpairs.1.tar.gz": "https://figshare.com/ndownloader/files/40271983", - "dimerpairs.2.tar.gz": "https://figshare.com/ndownloader/files/40271989", - "dimerpairs.3.tar.gz": "https://figshare.com/ndownloader/files/40272001", - "dimerpairs.4.tar.gz": "https://figshare.com/ndownloader/files/40272022", - "dimerpairs.5.tar.gz": "https://figshare.com/ndownloader/files/40552931", - "dimerpairs.6.tar.gz": "https://figshare.com/ndownloader/files/40272040", - "dimerpairs.7.tar.gz": "https://figshare.com/ndownloader/files/40272052", - "dimerpairs.8.tar.gz": "https://figshare.com/ndownloader/files/40272061", - "dimerpairs.9.tar.gz": "https://figshare.com/ndownloader/files/40272064", - "dimerpairs_nonstandard.tar.gz": "https://figshare.com/ndownloader/files/40272067", - "lig_interaction_sites.sdf": "https://figshare.com/ndownloader/files/40272070", - "lig_monomers.sdf": "https://figshare.com/ndownloader/files/40272073", - "prot_interaction_sites.sdf": "https://figshare.com/ndownloader/files/40272076", - "prot_monomers.sdf": "https://figshare.com/ndownloader/files/40272079", - "merge_monomers.py": "https://figshare.com/ndownloader/files/41807682", - }, - ) - - des370k = dict( - dataset_name="des370k_interaction", - links={ - "DES370K.zip": "https://zenodo.org/record/5676266/files/DES370K.zip", - }, - ) - - des5m = dict( - dataset_name="des5m_interaction", - links={ - "DES5M.zip": "https://zenodo.org/records/5706002/files/DESS5M.zip?download=1", - }, - ) - - tmqm = dict( - dataset_name="tmqm", - links={ - x: f"https://raw.githubusercontent.com/bbskjelstad/tmqm/master/data/{x}" - for x in ["tmQM_X1.xyz.gz", "tmQM_X2.xyz.gz", "tmQM_y.csv", "Benchmark2_TPSSh_Opt.xyz"] - }, - ) - - metcalf = dict( - dataset_name="metcalf", - links={"model-data.tar.gz": "https://zenodo.org/records/10934211/files/model-data.tar?download=1"}, - ) - - misato = dict( - dataset_name="misato", - links={ - "MD.hdf5": "https://zenodo.org/record/7711953/files/MD.hdf5", - "QM.hdf5": "https://zenodo.org/record/7711953/files/QM.hdf5", - }, - ) - - nabladft = dict( - dataset_name="nabladft", - links={"nabladft.db": "https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_full.db"}, - cmd=[ - "axel -n 10 --output=dataset_full.db https://n-usr-31b1j.s3pd12.sbercloud.ru/b-usr-31b1j-qz9/data/moses_db/dataset_full.db" - ], - ) - - pubchemqc = dict( - dataset_name="pubchemqc", - links={ - "pqcm_b3lyp_2017.tar.gz": "https://chibakoudai.sharepoint.com/:u:/s/stair02/Ed9Z16k0ctJKk9nQLMYFHYUBp_E9zerPApRaWTrOIYN-Eg" - }, - cmd=[ - 'wget "https://chibakoudai.sharepoint.com/:u:/s/stair06/EcWMtOpIEqFLrHcR1dzlZiMBLhTFY0RZ0qPaqC4lhRp51A?download=1" -O b3lyp_pm6_ver1.0.1-postgrest-docker-compose.tar.xz.rclone_chunk.001', - 'wget "https://chibakoudai.sharepoint.com/:u:/s/stair06/EbJe-SlL4oNPhOpOtA8mxLsB1F3eI2l-5RS315hIZUFNwQ?download=1" -O b3lyp_pm6_ver1.0.1-postgrest-docker-compose.tar.xz.rclone_chunk.002', - "cat b3lyp_pm6_ver1.0.1-postgrest-docker-compose.tar.xz.rclone_chunk.001 b3lyp_pm6_ver1.0.1-postgrest-docker-compose.tar.xz.rclone_chunk.002 | tar xvfJ - ", - ], - ) - - multixcqm9 = dict( - dataset_name="multixcqm9", - links={ - "xyz.zip": "https://data.dtu.dk/ndownloader/files/35143624", - "xtb.zip": "https://data.dtu.dk/ndownloader/files/42444300", - "dzp.zip": "https://data.dtu.dk/ndownloader/files/42443925", - "tzp.zip": "https://data.dtu.dk/ndownloader/files/42444129", - "sz.zip": "https://data.dtu.dk/ndownloader/files/42441345", - "failed_indices.dat": "https://data.dtu.dk/ndownloader/files/37337677", - }, - ) - - transition1x = dict( - dataset_name="transition1x", - links={"Transition1x.h5": "https://figshare.com/ndownloader/files/36035789"}, - ) - - dess66 = dict( - dataset_name="des_s66", - links={"DESS66.zip": "https://zenodo.org/records/5676284/files/DESS66.zip?download=1"}, - ) - - dess66x8 = dict( - dataset_name="des_s66x8", - links={"DESS66x8.zip": "https://zenodo.org/records/5676284/files/DESS66x8.zip?download=1"}, - ) - revmd17 = dict( - dataset_name="revmd17", - links={"revmd17.zip": "https://figshare.com/ndownloader/articles/12672038/versions/3"}, - ) - md22 = dict( - dataset_name="md22", - links={ - f"{x}.npz": f"http://www.quantum-machine.org/gdml/repo/datasets/md22_{x}.npz" - for x in [ - "Ac-Ala3-NHMe", - "DHA", - "stachyose", - "AT-AT", - "AT-AT-CG-CG", - "double-walled_nanotube", - "buckyball-catcher", - ] - }, - ) - - x40 = dict( - dataset_name="x40", - links={ - "x40.yaml": "http://cuby4.molecular.cz/download_datasets/x40.yaml", - "geometries.tar.gz": "http://cuby4.molecular.cz/download_geometries/X40.tar", - }, - ) - - available_datasets = [k for k in locals().keys() if not k.startswith("__")] - - def __init__(self): - pass - - def __call__(self, dataset_name): - return getattr(self, dataset_name) - - -class DataDownloader: - """Download data from a remote source. - Parameters - ---------- - cache_path : str - Path to the cache directory. - overwrite : bool - Whether to overwrite existing files. - """ - - def __init__(self, cache_path=None, overwrite=False): - if cache_path is None: - cache_path = get_local_cache() - - self.cache_path = cache_path - self.overwrite = overwrite - - def from_config(self, config: dict): - b_config = Bunch(**config) - data_path = os.path.join(self.cache_path, b_config.dataset_name) - os.makedirs(data_path, exist_ok=True) - - logger.info(f"Downloading the {b_config.dataset_name} dataset") - for local, link in b_config.links.items(): - outfile = os.path.join(data_path, local) - - fetch_file(link, outfile) - - def from_name(self, name): - cfg = DataConfigFactory()(name) - return self.from_config(cfg) diff --git a/openqdc/utils/download_api.py b/openqdc/utils/download_api.py new file mode 100644 index 0000000..99343db --- /dev/null +++ b/openqdc/utils/download_api.py @@ -0,0 +1,179 @@ +import gzip +import os +import shutil +import socket +import tarfile +import urllib.error +import urllib.request +import zipfile + +import fsspec +import gdown +import requests +import tqdm +from loguru import logger +from sklearn.utils import Bunch + +from openqdc.utils.io import get_local_cache + + +def download_url(url, local_filename): + """ + Download a file from a url to a local file. + Parameters + ---------- + url : str + URL to download from. + local_filename : str + Local path for destination. + """ + logger.info(f"Url: {url} File: {local_filename}") + if "drive.google.com" in url: + gdown.download(url, local_filename, quiet=False) + elif "raw.github" in url: + r = requests.get(url, allow_redirects=True) + with open(local_filename, "wb") as f: + f.write(r.content) + else: + r = requests.get(url, stream=True) + with fsspec.open(local_filename, "wb") as f: + for chunk in tqdm.tqdm(r.iter_content(chunk_size=16384)): + if chunk: + f.write(chunk) + + +def decompress_tar_gz(local_filename): + """ + Decompress a tar.gz file. + Parameters + ---------- + local_filename : str + Path to local file to decompress. + """ + parent = os.path.dirname(local_filename) + with tarfile.open(local_filename) as tar: + logger.info(f"Verifying archive extraction states: {local_filename}") + all_names = tar.getnames() + all_extracted = all([os.path.exists(os.path.join(parent, x)) for x in all_names]) + if not all_extracted: + logger.info(f"Extracting archive: {local_filename}") + tar.extractall(path=parent) + else: + logger.info(f"Archive already extracted: {local_filename}") + + +def decompress_zip(local_filename): + """ + Decompress a zip file. + Parameters + ---------- + local_filename : str + Path to local file to decompress. + """ + parent = os.path.dirname(local_filename) + + logger.info(f"Verifying archive extraction states: {local_filename}") + with zipfile.ZipFile(local_filename, "r") as zip_ref: + all_names = zip_ref.namelist() + all_extracted = all([os.path.exists(os.path.join(parent, x)) for x in all_names]) + if not all_extracted: + logger.info(f"Extracting archive: {local_filename}") + zip_ref.extractall(parent) + else: + logger.info(f"Archive already extracted: {local_filename}") + + +def decompress_gz(local_filename): + """ + Decompress a gz file. + Parameters + ---------- + local_filename : str + Path to local file to decompress. + """ + logger.info(f"Verifying archive extraction states: {local_filename}") + out_filename = local_filename.replace(".gz", "") + if out_filename.endswith("hdf5"): + out_filename = local_filename.replace("hdf5", "h5") + + all_extracted = os.path.exists(out_filename) + if not all_extracted: + logger.info(f"Extracting archive: {local_filename}") + with gzip.open(local_filename, "rb") as f_in, open(out_filename, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + else: + logger.info(f"Archive already extracted: {local_filename}") + + +def fetch_file(url, local_filename, overwrite=False): + """ + Download a file from a url to a local file. Useful for big files. + Parameters + ---------- + url : str + URL to download from. + local_filename : str + Local file to save to. + overwrite : bool + Whether to overwrite existing files. + Returns + ------- + local_filename : str + Local file. + """ + try: + if os.path.exists(local_filename) and not overwrite: + logger.info("File already exists, skipping download") + else: + download_url(url, local_filename) + + # decompress archive if necessary + parent = os.path.dirname(local_filename) + if local_filename.endswith("tar.gz"): + decompress_tar_gz(local_filename) + + elif local_filename.endswith("zip"): + decompress_zip(local_filename) + + elif local_filename.endswith(".gz"): + decompress_gz(local_filename) + + elif local_filename.endswith("xz"): + logger.info(f"Extracting archive: {local_filename}") + os.system(f"cd {parent} && xz -d *.xz") + + else: + pass + + except (socket.gaierror, urllib.error.URLError) as err: + raise ConnectionError("Could not download {} due to {}".format(url, err)) + + return local_filename + + +class DataDownloader: + """Download data from a remote source. + Parameters + ---------- + cache_path : str + Path to the cache directory. + overwrite : bool + Whether to overwrite existing files. + """ + + def __init__(self, cache_path=None, overwrite=False): + if cache_path is None: + cache_path = get_local_cache() + + self.cache_path = cache_path + self.overwrite = overwrite + + def from_config(self, config: dict): + b_config = Bunch(**config) + data_path = os.path.join(self.cache_path, b_config.dataset_name) + os.makedirs(data_path, exist_ok=True) + + logger.info(f"Downloading the {b_config.dataset_name} dataset") + for local, link in b_config.links.items(): + outfile = os.path.join(data_path, local) + fetch_file(link, outfile) diff --git a/openqdc/utils/regressor.py b/openqdc/utils/regressor.py index c230ce7..1d3e50a 100644 --- a/openqdc/utils/regressor.py +++ b/openqdc/utils/regressor.py @@ -146,6 +146,8 @@ def solve(self): else: X, y = self.X, self.y[:, energy_idx] E0s, cov = self.solver(X, y) + if cov is None: + cov = np.zeros_like(E0s) + 1.0 E0_list.append(E0s) cov_list.append(cov) return np.vstack(E0_list).T, np.vstack(cov_list).T diff --git a/tests/test_energies.py b/tests/test_energies.py new file mode 100644 index 0000000..47a8e7b --- /dev/null +++ b/tests/test_energies.py @@ -0,0 +1,44 @@ +import numpy as np +import pytest + +from openqdc.datasets.energies import AtomEnergies, AtomEnergy +from openqdc.methods import PotentialMethod + + +class Container: + __name__ = "container" + __energy_methods__ = [PotentialMethod.WB97M_D3BJ_DEF2_TZVPPD] + energy_methods = [str(PotentialMethod.WB97M_D3BJ_DEF2_TZVPPD)] + refit_e0s = True + + def __init__(self, energy_type="formation"): + self.energy_type = energy_type + + +@pytest.fixture +def physical_energies(): + dummy = Container() + return AtomEnergies(dummy) + + +def test_atom_energies_object(physical_energies): + assert isinstance(physical_energies, AtomEnergies) + + +def test_indexing(physical_energies): + assert isinstance(physical_energies[6], AtomEnergy) + assert isinstance(physical_energies[(6, 1)], AtomEnergy) + assert isinstance(physical_energies[6, 1], AtomEnergy) + assert isinstance(physical_energies[("C", 1)], AtomEnergy) + assert isinstance(physical_energies["C", 1], AtomEnergy) + assert physical_energies[("C", 1)] == physical_energies[(6, 1)] + assert not physical_energies[("Cl", -2)] == physical_energies[(6, 1)] + with pytest.raises(KeyError): + physical_energies[("Cl", -6)] + + +def test_matrix(physical_energies): + matrix = physical_energies.e0s_matrix + assert len(matrix) == 1 + assert isinstance(matrix, np.ndarray) + assert np.any(matrix)