Skip to content

Commit

Permalink
Merge pull request #10 from OpenDrugDiscovery/adding_datasets
Browse files Browse the repository at this point in the history
Adding datasets
  • Loading branch information
prtos authored Oct 12, 2023
2 parents 7a9905b + f2e1664 commit 035fabb
Show file tree
Hide file tree
Showing 13 changed files with 686 additions and 75 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ nohup.out
*.xyz
*.csv
*.txt
*.sh
110 changes: 110 additions & 0 deletions src/openqdc/datasets/dess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from os.path import join as p_join

import datamol as dm
import numpy as np
import pandas as pd
from tqdm import tqdm

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.molecule import get_atomic_number_and_charge


def read_mol(mol_path, smiles, subset, targets):
try:
with open(mol_path, "r") as f:
mol_block = f.read()
mol = dm.read_molblock(mol_block, remove_hs=False, fail_if_invalid=True)

x = get_atomic_number_and_charge(mol)
positions = mol.GetConformer().GetPositions()

res = dict(
name=np.array([smiles]),
subset=np.array([subset]),
energies=np.array(targets).astype(np.float32)[None, :],
atomic_inputs=np.concatenate((x, positions), axis=-1, dtype=np.float32),
n_atoms=np.array([x.shape[0]], dtype=np.int32),
)
except Exception as e:
print(f"Skipping: {mol_path} due to {e}")
res = None

return res


class DESS(BaseDataset):
__name__ = "dess"
__energy_methods__ = [
"mp2_cc",
"mp2_qz",
"mp2_tz",
"mp2_cbs",
"ccsd(t)_cc",
"ccsd(t)_cbs",
"ccsd(t)_nn",
"sapt",
]

energy_target_names = [
"cc_MP2_all",
"qz_MP2_all",
"tz_MP2_all",
"cbs_MP2_all",
"cc_CCSD(T)_all",
"cbs_CCSD(T)_all",
"nn_CCSD(T)_all",
"sapt_all",
]
# ['qz_MP2_all', 'tz_MP2_all', 'cbs_MP2_all', 'sapt_all', 'nn_CCSD(T)_all']

# Energy in hartree, all zeros by default
atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32)

partitions = ["DES370K", "DES5M"]

def __init__(self) -> None:
super().__init__()

def _read_raw_(self, part):
df = pd.read_csv(p_join(self.root, f"{part}.csv"))
for col in self.energy_target_names:
if col not in df.columns:
df[col] = np.nan
smiles = (df["smiles0"] + "." + df["smiles1"]).tolist()
subsets = (f"{part}_" + df["group_orig"]).tolist()
targets = df[self.energy_target_names].values
paths = (
p_join(self.root, "geometries/")
+ df["system_id"].astype(str)
+ f"/{part}_"
+ df["geom_id"].astype(str)
+ ".mol"
)

inputs = [
dict(smiles=smiles[i], subset=subsets[i], targets=targets[i], mol_path=paths[i])
for i in tqdm(range(len(smiles)))
]
f = lambda xs: [read_mol(**x) for x in xs]
samples = dm.parallelized_with_batches(
f, inputs, n_jobs=-1, progress=True, batch_size=1024, scheduler="threads"
)
return samples

def read_raw_entries(self):
samples = sum([self._read_raw_(partition) for partition in self.partitions], [])
return samples


if __name__ == "__main__":
for data_class in [DESS]:
data = data_class()
n = len(data)

for i in np.random.choice(n, 3, replace=False):
x = data[i]
print(x.name, x.subset, end=" ")
for k in x:
if x[k] is not None:
print(k, x[k].shape, end=" ")
4 changes: 2 additions & 2 deletions src/openqdc/datasets/orbnet_denali.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from openqdc.utils.molecule import atom_table


def read_mol(mol_id, conf_dict, base_path, energy_target_names):
def read_archive(mol_id, conf_dict, base_path, energy_target_names):
res = []
for conf_id, conf_label in conf_dict.items():
try:
Expand Down Expand Up @@ -60,7 +60,7 @@ def read_raw_entries(self):
# if i > 10:
# break
# exit()
fn = lambda x: read_mol(x[0], x[1], self.root, self.energy_target_names)
fn = lambda x: read_archive(x[0], x[1], self.root, self.energy_target_names)
res = dm.parallelized(fn, list(labels.items()), scheduler="threads", n_jobs=-1, progress=True)
samples = sum(res, [])
return samples
Expand Down
94 changes: 94 additions & 0 deletions src/openqdc/datasets/pcqm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import json
import tarfile
from glob import glob
from os.path import join as p_join

import datamol as dm
import numpy as np
import pandas as pd

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER


def flatten_dict(d, sep: str = "."):
return pd.json_normalize(d, sep=sep).to_dict(orient="records")[0]


def read_content(f):
try:
r = flatten_dict(json.load(f))
x = np.concatenate(
(
r["atoms.elements.number"][:, None],
r["atoms.core electrons"][:, None],
r["atoms.coords.3d"].reshape(-1, 3),
),
axis=-1,
).astype(np.float32)

res = dict(
name=np.array([r["smiles"]]),
subset=np.array([r["formula"]]),
energies=np.array(["properties.energy.total"]).astype(np.float32)[None, :],
atomic_inputs=x,
n_atoms=np.array([x.shape[0]], dtype=np.int32),
)
except Exception:
res = None

return res


def read_archive(path):
with tarfile.open(path) as tar:
res = [read_content(tar.extractfile(member)) for member in tar.getmembers()]
# print(len(res))
return res


class PubchemQC(BaseDataset):
__name__ = "pubchemqc"
__energy_methods__ = [
"b3lyp",
"pm6",
]

energy_target_names = [
"b3lyp",
"pm6",
]

# Energy in hartree, all zeros by default
atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32)

partitions = ["b3lyp", "pm6"]

def __init__(self) -> None:
super().__init__()

def _read_raw_(self, part):
arxiv_paths = glob(p_join(self.root, f"{part}", "*.tar.gz"))
print(len(arxiv_paths))
samples = dm.parallelized(read_archive, arxiv_paths, n_jobs=-1, progress=True, scheduler="threads")
res = sum(samples, [])
print(len(res))
exit()
return res

def read_raw_entries(self):
samples = sum([self._read_raw_(partition) for partition in self.partitions], [])
return samples


if __name__ == "__main__":
for data_class in [PubchemQC]:
data = data_class()
n = len(data)

for i in np.random.choice(n, 3, replace=False):
x = data[i]
print(x.name, x.subset, end=" ")
for k in x:
if x[k] is not None:
print(k, x[k].shape, end=" ")
62 changes: 42 additions & 20 deletions src/openqdc/datasets/qm7x.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,69 @@
from os.path import join as p_join

import numpy as np
from tqdm import tqdm

from openqdc.datasets.base import BaseDataset, read_qc_archive_h5
from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.io import load_hdf5_file


class ISO17(BaseDataset):
__name__ = "iso_17"
def read_mol(mol_h5, mol_name, energy_target_names, force_target_names):
m = mol_h5
cids = list(mol_h5.keys())

zs = [m[c]["atNUM"] for c in cids]
xyz = np.concatenate([m[c]["atXYZ"] for c in cids], axis=0)
n_atoms = np.array([len(z) for z in zs], dtype=np.int32)
n, zs = len(n_atoms), np.concatenate(zs, axis=0)
a_inputs = np.concatenate([np.stack([zs, np.zeros_like(zs)], axis=-1), xyz], axis=-1)

forces = np.concatenate([np.stack([m[c][f_tag] for f_tag in force_target_names], axis=-1) for c in cids], axis=0)
energies = np.stack([np.array([m[c][e_tag][0] for e_tag in energy_target_names]) for c in cids], axis=0)

res = dict(
name=np.array([mol_name] * n),
subset=np.array(["qm7x"] * n),
energies=energies.astype(np.float32),
atomic_inputs=a_inputs.astype(np.float32),
forces=forces.astype(np.float32),
n_atoms=n_atoms,
)

return res


class QM7X(BaseDataset):
__name__ = "qm7x"

# Energy in hartree, all zeros by default
atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32)

__energy_methods__ = [
"pbe-ts",
]
__energy_methods__ = ["pbe-ts", "mbd"]

energy_target_names = [
"PBE-TS Energy",
]
energy_target_names = ["ePBE0", "eMBD"]

__force_methods__ = [
"pbe-ts",
]
__force_methods__ = ["pbe-ts", "vdw"]

force_target_names = [
"PBE-TS Gradient",
]
force_target_names = ["pbe0FOR", "vdwFOR"]

def __init__(self) -> None:
super().__init__()

def read_raw_entries(self):
raw_path = p_join(self.root, "iso_17.h5")
samples = read_qc_archive_h5(raw_path, "iso_17", self.energy_target_names, self.force_target_names)
samples = []
for i in range(1, 9):
raw_path = p_join(self.root, f"{i}000")
data = load_hdf5_file(raw_path)
samples += [
read_mol(data[k], k, self.energy_target_names, self.force_target_names) for k in tqdm(data.keys())
]

return samples


if __name__ == "__main__":
for data_class in [ISO17]:
for data_class in [QM7X]:
data = data_class()
n = len(data)

Expand All @@ -49,5 +73,3 @@ def read_raw_entries(self):
for k in x:
if x[k] is not None:
print(k, x[k].shape, end=" ")

print()
2 changes: 1 addition & 1 deletion src/openqdc/datasets/qmugs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

def read_mol(mol_dir):
filenames = glob(p_join(mol_dir, "*.sdf"))
mols = [dm.read_sdf(f)[0] for f in filenames]
mols = [dm.read_sdf(f, remove_hs=False)[0] for f in filenames]
n_confs = len(mols)

if len(mols) == 0:
Expand Down
14 changes: 7 additions & 7 deletions src/openqdc/datasets/sn2_rxn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,33 @@


class SN2RXN(BaseDataset):
__name__ = "iso_17"
__name__ = "sn2_rxn"

# Energy in hartree, all zeros by default
atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32)

__energy_methods__ = [
"pbe-ts",
"dsd-blyp-d3(bj)_tz",
]

energy_target_names = [
"PBE-TS Energy",
"DSD-BLYP-D3(BJ):def2-TZVP Atomization Energy",
]

__force_methods__ = [
"pbe-ts",
"dsd-blyp-d3(bj)_tz",
]

force_target_names = [
"PBE-TS Gradient",
"DSD-BLYP-D3(BJ):def2-TZVP Gradient",
]

def __init__(self) -> None:
super().__init__()

def read_raw_entries(self):
raw_path = p_join(self.root, "iso_17.h5")
samples = read_qc_archive_h5(raw_path, "iso_17", self.energy_target_names, self.force_target_names)
raw_path = p_join(self.root, "sn2_rxn.h5")
samples = read_qc_archive_h5(raw_path, "sn2_rxn", self.energy_target_names, self.force_target_names)

return samples

Expand Down
Loading

0 comments on commit 035fabb

Please sign in to comment.