Skip to content

Commit

Permalink
Merge pull request #51 from OpenDrugDiscovery/SpiceV2
Browse files Browse the repository at this point in the history
SpiceV2
  • Loading branch information
FNTwin authored Mar 20, 2024
2 parents 3df21cc + 6451f6f commit dddcea2
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
3 changes: 2 additions & 1 deletion openqdc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"ANI1CCX": "openqdc.datasets.potential.ani",
"ANI1X": "openqdc.datasets.potential.ani",
"Spice": "openqdc.datasets.potential.spice",
"SpiceV2": "openqdc.datasets.potential.spice",
"GEOM": "openqdc.datasets.potential.geom",
"QMugs": "openqdc.datasets.potential.qmugs",
"ISO17": "openqdc.datasets.potential.iso_17",
Expand Down Expand Up @@ -86,7 +87,7 @@ def __dir__():
from .datasets.potential.revmd17 import RevMD17 # noqa
from .datasets.potential.sn2_rxn import SN2RXN # noqa
from .datasets.potential.solvated_peptides import SolvatedPeptides # noqa
from .datasets.potential.spice import Spice # noqa
from .datasets.potential.spice import Spice, SpiceV2 # noqa
from .datasets.potential.tmqm import TMQM # noqa
from .datasets.potential.transition1x import Transition1X # noqa
from .datasets.potential.waterclusters3_30 import WaterClusters # noqa
3 changes: 2 additions & 1 deletion openqdc/datasets/potential/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .revmd17 import RevMD17 # noqa
from .sn2_rxn import SN2RXN # noqa
from .solvated_peptides import SolvatedPeptides # noqa
from .spice import Spice # noqa
from .spice import Spice, SpiceV2 # noqa
from .tmqm import TMQM # noqa
from .transition1x import Transition1X # noqa
from .waterclusters3_30 import WaterClusters # noqa
Expand All @@ -37,6 +37,7 @@
"sn2rxn": SN2RXN,
"solvatedpeptides": SolvatedPeptides,
"spice": Spice,
"spicev2": SpiceV2,
"tmqm": TMQM,
"transition1x": Transition1X,
"watercluster": WaterClusters,
Expand Down
69 changes: 64 additions & 5 deletions openqdc/datasets/potential/spice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from openqdc.utils.molecule import get_atomic_number_and_charge


def read_record(r):
def read_record(r, obj):
"""
Read record from hdf5 file.
r : hdf5 record
obj : Spice class object used to grab subset and names
"""
smiles = r["smiles"].asstr()[0]
subset = r["subset"][0].decode("utf-8")
n_confs = r["conformations"].shape[0]
Expand All @@ -18,9 +23,9 @@ def read_record(r):

res = dict(
name=np.array([smiles] * n_confs),
subset=np.array([Spice.subset_mapping[subset]] * n_confs),
energies=r[Spice.energy_target_names[0]][:][:, None].astype(np.float32),
forces=r[Spice.force_target_names[0]][:].reshape(
subset=np.array([obj.subset_mapping[subset]] * n_confs),
energies=r[obj.energy_target_names[0]][:][:, None].astype(np.float32),
forces=r[obj.force_target_names[0]][:].reshape(
-1, 3, 1
), # forces -ve of energy gradient but the -1.0 is done in the convert_forces method
atomic_inputs=np.concatenate(
Expand Down Expand Up @@ -82,6 +87,60 @@ def read_raw_entries(self):
raw_path = p_join(self.root, "SPICE-1.1.4.hdf5")

data = load_hdf5_file(raw_path)
tmp = [read_record(data[mol_name]) for mol_name in tqdm(data)] # don't use parallelized here
tmp = [read_record(data[mol_name], self) for mol_name in tqdm(data)] # don't use parallelized here

return tmp


class SpiceV2(Spice):
"""
SpiceV2 dataset augmented with amino acids complexes, water boxes,
pubchem solvated molecules.
It consists of both forces and energies calculated
at the {\omega}B97M-D3(BJ)/def2-TZVPPD level of theory.
Usage:
```python
from openqdc.datasets import SpiceV2
dataset = SpiceV2()
```
References:
- https://github.com/openmm/spice-dataset/releases/tag/2.0.0
- https://github.com/openmm/spice-dataset
"""

__name__ = "spicev2"

subset_mapping = {
"SPICE Dipeptides Single Points Dataset v1.3": "Dipeptides",
"SPICE Solvated Amino Acids Single Points Dataset v1.1": "Solvated Amino Acids",
"SPICE Water Clusters v1.0": "Water Clusters",
"SPICE Solvated PubChem Set 1 v1.0": "Solvated PubChem",
"SPICE Amino Acid Ligand v1.0": "Amino Acid Ligand",
"SPICE PubChem Set 1 Single Points Dataset v1.3": "PubChem",
"SPICE PubChem Set 2 Single Points Dataset v1.3": "PubChem",
"SPICE PubChem Set 3 Single Points Dataset v1.3": "PubChem",
"SPICE PubChem Set 4 Single Points Dataset v1.3": "PubChem",
"SPICE PubChem Set 5 Single Points Dataset v1.3": "PubChem",
"SPICE PubChem Set 6 Single Points Dataset v1.3": "PubChem",
"SPICE PubChem Set 7 Single Points Dataset v1.0": "PubChem",
"SPICE PubChem Set 8 Single Points Dataset v1.0": "PubChem",
"SPICE PubChem Set 9 Single Points Dataset v1.0": "PubChem",
"SPICE PubChem Set 10 Single Points Dataset v1.0": "PubChem",
"SPICE DES Monomers Single Points Dataset v1.1": "DES370K Monomers",
"SPICE DES370K Single Points Dataset v1.0": "DES370K Dimers",
"SPICE DES370K Single Points Dataset Supplement v1.1": "DES370K Dimers",
"SPICE PubChem Boron Silicon v1.0": "PubChem Boron Silicon",
"SPICE Ion Pairs Single Points Dataset v1.2": "Ion Pairs",
}

def read_raw_entries(self):
raw_path = p_join(self.root, "spice-2.0.0.hdf5")

data = load_hdf5_file(raw_path)
# Entry 40132 without positions, skip it
# don't use parallelized here
tmp = [read_record(data[mol_name], self) for i, mol_name in enumerate(tqdm(data)) if i != 40132]

return tmp
4 changes: 4 additions & 0 deletions openqdc/raws/config_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ class DataConfigFactory:
dataset_name="spice",
links={"SPICE-1.1.4.hdf5": "https://zenodo.org/record/8222043/files/SPICE-1.1.4.hdf5"},
)
spicev2 = dict(
dataset_name="spicev2",
links={"spice-2.0.0.hdf5": "https://zenodo.org/records/10835749/files/SPICE-2.0.0.hdf5?download=1"},
)

dess = dict(
dataset_name="dess5m",
Expand Down

0 comments on commit dddcea2

Please sign in to comment.