Skip to content

Commit

Permalink
Merge pull request #86 from OpenDrugDiscovery/downloader_add
Browse files Browse the repository at this point in the history
unique enums + initial structure for api endpoint
  • Loading branch information
prtos authored Jun 8, 2024
2 parents 2669d21 + c2d13ad commit 6ffa749
Show file tree
Hide file tree
Showing 21 changed files with 334 additions and 132 deletions.
3 changes: 2 additions & 1 deletion openqdc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def get_project_root():
"ANI1CCX": "openqdc.datasets.potential.ani",
"ANI1CCX_V2": "openqdc.datasets.potential.ani",
"ANI1X": "openqdc.datasets.potential.ani",
"ANI2": "openqdc.datasets.potential.ani",
"Spice": "openqdc.datasets.potential.spice",
"SpiceV2": "openqdc.datasets.potential.spice",
"SpiceVL2": "openqdc.datasets.potential.spice",
Expand Down Expand Up @@ -100,7 +101,7 @@ def __dir__():
from .datasets.interaction.metcalf import Metcalf
from .datasets.interaction.splinter import Splinter
from .datasets.interaction.x40 import X40
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X
from .datasets.potential.ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2
from .datasets.potential.comp6 import COMP6
from .datasets.potential.dummy import Dummy
from .datasets.potential.gdml import GDML
Expand Down
26 changes: 16 additions & 10 deletions openqdc/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@
app = typer.Typer(help="OpenQDC CLI")


def sanitize(dictionary):
return {k.lower().replace("_", "").replace("-", ""): v for k, v in dictionary.items()}


SANITIZED_AVAILABLE_DATASETS = sanitize(AVAILABLE_DATASETS)


def exist_dataset(dataset):
if dataset not in AVAILABLE_DATASETS:
if dataset not in sanitize(AVAILABLE_DATASETS):
logger.error(f"{dataset} is not available. Please open an issue on Github for the team to look into it.")
return False
return True
Expand Down Expand Up @@ -57,10 +64,10 @@ def download(
"""
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
if exist_dataset(dataset):
if AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite:
if SANITIZED_AVAILABLE_DATASETS[dataset].no_init().is_cached() and not overwrite:
logger.info(f"{dataset} is already cached. Skipping download")
else:
AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir)
SANITIZED_AVAILABLE_DATASETS[dataset](overwrite_local_cache=True, cache_dir=cache_dir)


@app.command()
Expand Down Expand Up @@ -115,18 +122,17 @@ def fetch(
openqdc fetch Spice
"""
if datasets[0].lower() == "all":
dataset_names = AVAILABLE_DATASETS
dataset_names = list(sanitize(AVAILABLE_DATASETS).keys())
elif datasets[0].lower() == "potential":
dataset_names = AVAILABLE_POTENTIAL_DATASETS
dataset_names = list(sanitize(AVAILABLE_POTENTIAL_DATASETS).keys())
elif datasets[0].lower() == "interaction":
dataset_names = AVAILABLE_INTERACTION_DATASETS
dataset_names = list(sanitize(AVAILABLE_INTERACTION_DATASETS).keys())
else:
dataset_names = datasets

for dataset in list(map(lambda x: x.lower().replace("_", ""), dataset_names)):
if exist_dataset(dataset):
try:
AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
SANITIZED_AVAILABLE_DATASETS[dataset].fetch(cache_dir, overwrite)
except Exception as e:
logger.error(f"Something unexpected happended while fetching {dataset}: {repr(e)}")

Expand All @@ -152,9 +158,9 @@ def preprocess(
"""
for dataset in list(map(lambda x: x.lower().replace("_", ""), datasets)):
if exist_dataset(dataset):
logger.info(f"Preprocessing {AVAILABLE_DATASETS[dataset].__name__}")
logger.info(f"Preprocessing {SANITIZED_AVAILABLE_DATASETS[dataset].__name__}")
try:
AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite)
SANITIZED_AVAILABLE_DATASETS[dataset].no_init().preprocess(upload=upload, overwrite=overwrite)
except Exception as e:
logger.error(f"Error while preprocessing {dataset}. {e}. Did you fetch the dataset first?")
raise e
Expand Down
6 changes: 3 additions & 3 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
array_format: str = "numpy",
energy_type: str = "formation",
energy_type: Optional[str] = "formation",
overwrite_local_cache: bool = False,
cache_dir: Optional[str] = None,
recompute_statistics: bool = False,
Expand All @@ -112,7 +112,7 @@ def __init__(
Format to return arrays in. Supported formats: ["numpy", "torch", "jax"]
energy_type
Type of isolated atom energy to use for the dataset. Default: "formation"
Supported types: ["formation", "regression", "null"]
Supported types: ["formation", "regression", "null", None]
overwrite_local_cache
Whether to overwrite the locally cached dataset.
cache_dir
Expand All @@ -133,7 +133,7 @@ def __init__(
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
self.energy_type = energy_type
self.energy_type = energy_type if energy_type is not None else "null"
self.refit_e0s = recompute_statistics or overwrite_local_cache
if not self.is_preprocessed():
raise DatasetNotAvailableError(self.__name__)
Expand Down
3 changes: 1 addition & 2 deletions openqdc/datasets/energies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
from loguru import logger

from openqdc.methods.enums import PotentialMethod
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS, MAX_CHARGE_NUMBER
from openqdc.utils.io import load_pkl, save_pkl
from openqdc.utils.regressor import Regressor

POSSIBLE_ENERGIES = ["formation", "regression", "null"]
MAX_CHARGE_NUMBER = 21


def dispatch_factory(data, **kwargs) -> "IsolatedEnergyInterface":
Expand Down
16 changes: 8 additions & 8 deletions openqdc/datasets/interaction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
from .x40 import X40

AVAILABLE_INTERACTION_DATASETS = {
"des5m": DES5M,
"des370k": DES370K,
"dess66": DESS66,
"dess66x8": DESS66x8,
"l7": L7,
"metcalf": Metcalf,
"splinter": Splinter,
"x40": X40,
"DES5M": DES5M,
"DES370K": DES370K,
"DESS66": DESS66,
"DESS66x8": DESS66x8,
"L7": L7,
"Metcalf": Metcalf,
"Splinter": Splinter,
"X40": X40,
}
60 changes: 30 additions & 30 deletions openqdc/datasets/potential/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X
from .ani import ANI1, ANI1CCX, ANI1CCX_V2, ANI1X, ANI2
from .comp6 import COMP6
from .dummy import Dummy
from .gdml import GDML
Expand All @@ -21,33 +21,33 @@
from .waterclusters3_30 import WaterClusters

AVAILABLE_POTENTIAL_DATASETS = {
"ani1": ANI1,
"ani1ccx": ANI1CCX,
"ani1ccxv2": ANI1CCX_V2,
"ani1x": ANI1X,
"comp6": COMP6,
"gdml": GDML,
"geom": GEOM,
"iso17": ISO17,
"molecule3d": Molecule3D,
"nabladft": NablaDFT,
"orbnetdenali": OrbnetDenali,
"pcqmb3lyp": PCQM_B3LYP,
"pcqmpm6": PCQM_PM6,
"qm7x": QM7X,
"qm7xv2": QM7X_V2,
"qmugs": QMugs,
"qmugsv2": QMugs_V2,
"sn2rxn": SN2RXN,
"solvatedpeptides": SolvatedPeptides,
"spice": Spice,
"spicev2": SpiceV2,
"spicevl2": SpiceVL2,
"tmqm": TMQM,
"transition1x": Transition1X,
"watercluster": WaterClusters,
"multixcqm9": MultixcQM9,
"multixcqm9v2": MultixcQM9_V2,
"revmd17": RevMD17,
"md22": MD22,
"ANI1": ANI1,
"ANI1CCX": ANI1CCX,
"ANI1CCX_V2": ANI1CCX_V2,
"ANI1X": ANI1X,
"COMP6": COMP6,
"GDML": GDML,
"GEOM": GEOM,
"ISO17": ISO17,
"Molecule3D": Molecule3D,
"NablaDFT": NablaDFT,
"OrbnetDenali": OrbnetDenali,
"PCQM_B3LYP": PCQM_B3LYP,
"PCQM_PM6": PCQM_PM6,
"QM7X": QM7X,
"QM7X_V2": QM7X_V2,
"QMugs": QMugs,
"QMugs_V2": QMugs_V2,
"SN2RXN": SN2RXN,
"SolvatedPeptides": SolvatedPeptides,
"Spice": Spice,
"SpiceV2": SpiceV2,
"SpiceVL2": SpiceVL2,
"TMQM": TMQM,
"Transition1X": Transition1X,
"WaterClusters": WaterClusters,
"MultixcQM9": MultixcQM9,
"MultixcQM9_V2": MultixcQM9_V2,
"RevMD17": RevMD17,
"MD22": MD22,
}
80 changes: 79 additions & 1 deletion openqdc/datasets/potential/ani.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
import os
from os.path import join as p_join

import numpy as np

from openqdc.datasets.base import BaseDataset
from openqdc.methods import PotentialMethod
from openqdc.utils import read_qc_archive_h5
from openqdc.utils import load_hdf5_file, read_qc_archive_h5
from openqdc.utils.io import get_local_cache


def read_ani2_h5(raw_path):
h5f = load_hdf5_file(raw_path)
samples = []
for _, props in h5f.items():
samples.append(extract_ani2_entries(props))
return samples


def extract_ani2_entries(properties):
coordinates = properties["coordinates"]
species = properties["species"]
forces = properties["forces"]
energies = properties["energies"]
n_atoms = coordinates.shape[1]
n_entries = coordinates.shape[0]
flattened_coordinates = coordinates[:].reshape((-1, 3))
xs = np.stack((species[:].flatten(), np.zeros(flattened_coordinates.shape[0])), axis=-1)
res = dict(
name=np.array(["ANI2"] * n_entries),
subset=np.array([str(n_atoms)] * n_entries),
energies=energies[:].reshape((-1, 1)).astype(np.float64),
atomic_inputs=np.concatenate((xs, flattened_coordinates), axis=-1, dtype=np.float32),
n_atoms=np.array([n_atoms] * n_entries, dtype=np.int32),
forces=forces[:].reshape(-1, 3, 1).astype(np.float32),
)
return res


class ANI1(BaseDataset):
"""
The ANI-1 dataset is a collection of 22 x 10^6 structural conformations from 57,000 distinct small
Expand Down Expand Up @@ -176,3 +206,51 @@ class ANI1CCX_V2(ANI1CCX):

__energy_methods__ = ANI1CCX.__energy_methods__ + [PotentialMethod.PM6, PotentialMethod.GFN2_XTB]
energy_target_names = ANI1CCX.energy_target_names + ["PM6", "GFN2"]


class ANI2(ANI1):
""" """

__name__ = "ani2"
__energy_unit__ = "hartree"
__distance_unit__ = "ang"
__forces_unit__ = "hartree/ang"

__energy_methods__ = [
# PotentialMethod.NONE, # "b973c/def2mtzvp",
PotentialMethod.WB97X_6_31G_D, # "wb97x/631gd", # PAPER DATASET
# PotentialMethod.NONE, # "wb97md3bj/def2tzvpp",
# PotentialMethod.NONE, # "wb97mv/def2tzvpp",
# PotentialMethod.NONE, # "wb97x/def2tzvpp",
]

energy_target_names = [
# "b973c/def2mtzvp",
"wb97x/631gd",
# "wb97md3bj/def2tzvpp",
# "wb97mv/def2tzvpp",
# "wb97x/def2tzvpp",
]

force_target_names = ["wb97x/631gd"] # "b973c/def2mtzvp",

__force_mask__ = [True]
__links__ = { # "ANI-2x-B973c-def2mTZVP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-B973c-def2mTZVP.tar.gz?download=1", # noqa
# "ANI-2x-wB97MD3BJ-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97MD3BJ-def2TZVPP.tar.gz?download=1", # noqa
# "ANI-2x-wB97MV-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97MV-def2TZVPP.tar.gz?download=1", # noqa
"ANI-2x-wB97X-631Gd.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97X-631Gd.tar.gz?download=1", # noqa
# "ANI-2x-wB97X-def2TZVPP.tar.gz": "https://zenodo.org/records/10108942/files/ANI-2x-wB97X-def2TZVPP.tar.gz?download=1", # noqa
}

def __smiles_converter__(self, x):
"""util function to convert string to smiles: useful if the smiles is
encoded in a different format than its display format
"""
return x

def read_raw_entries(self):
samples = []
for lvl_theory in self.__links__.keys():
raw_path = p_join(self.root, "final_h5", f"{lvl_theory.split('.')[0]}.h5")
samples.extend(read_ani2_h5(raw_path))
return samples
4 changes: 2 additions & 2 deletions openqdc/datasets/potential/comp6.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class COMP6(BaseDataset):

# watchout that forces are stored as -grad(E)
__energy_unit__ = "kcal/mol"
__distance_unit__ = "bohr" # bohr
__forces_unit__ = "kcal/mol/bohr"
__distance_unit__ = "ang" # angstorm
__forces_unit__ = "kcal/mol/ang"

__energy_methods__ = [
PotentialMethod.WB97X_6_31G_D, # "wb97x/6-31g*",
Expand Down
4 changes: 2 additions & 2 deletions openqdc/datasets/potential/gdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class GDML(BaseDataset):
]

__energy_unit__ = "kcal/mol"
__distance_unit__ = "bohr"
__forces_unit__ = "kcal/mol/bohr"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"
__links__ = {
"gdb7_9.hdf5.gz": "https://zenodo.org/record/3588361/files/208.hdf5.gz",
"gdb10_13.hdf5.gz": "https://zenodo.org/record/3588364/files/209.hdf5.gz",
Expand Down
4 changes: 2 additions & 2 deletions openqdc/datasets/potential/iso_17.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class ISO17(BaseDataset):
]

__energy_unit__ = "ev"
__distance_unit__ = "bohr" # bohr
__forces_unit__ = "ev/bohr"
__distance_unit__ = "ang"
__forces_unit__ = "ev/ang"
__links__ = {"iso_17.hdf5.gz": "https://zenodo.org/record/3585907/files/216.hdf5.gz"}

def __smiles_converter__(self, x):
Expand Down
4 changes: 2 additions & 2 deletions openqdc/datasets/potential/molecule3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def read_mol(mol: Chem.rdchem.Mol, energy: float) -> Dict[str, np.ndarray]:
res = dict(
name=np.array([smiles]),
subset=np.array(["molecule3d"]),
energies=np.array([energy]).astype(np.float32)[:, None],
atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float64),
energies=np.array([energy]).astype(np.float64)[:, None],
atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float32),
n_atoms=np.array([x.shape[0]], dtype=np.int32),
)

Expand Down
2 changes: 1 addition & 1 deletion openqdc/datasets/potential/qm7x.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class QM7X(BaseDataset):

__energy_methods__ = [PotentialMethod.PBE0_DEF2_TZVP, PotentialMethod.DFT3B] # "pbe0/def2-tzvp", "dft3b"]

energy_target_names = ["ePBE0", "eMBD"]
energy_target_names = ["ePBE0+MBD", "eDFTB+MBD"]

__force_mask__ = [True, True]

Expand Down
Loading

0 comments on commit 6ffa749

Please sign in to comment.