Skip to content

Commit

Permalink
Merge pull request #2 from OpenDrugDiscovery/pre-commit
Browse files Browse the repository at this point in the history
Added and run pre-commit hooks
  • Loading branch information
shenoynikhil authored Sep 25, 2023
2 parents 0f21262 + 7db8240 commit 8d0cbc2
Show file tree
Hide file tree
Showing 24 changed files with 397 additions and 381 deletions.
2 changes: 1 addition & 1 deletion .github/SECURITY.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Security Policy

Please report any security-related issues directly to [email protected].
Please report any security-related issues directly to [email protected].
15 changes: 15 additions & 0 deletions .github/workflows/pre-commit-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Based on https://github.com/pre-commit/action
name: pre-commit

on:
pull_request:
push:
branches: [main]

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: pre-commit/action@v3.0.0
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,9 @@ cache/
*.hdf5
nohup.out
*.out
*.crt
*.crt
*.key
*.dat
*.xyz
*.csv
*.txt

26 changes: 26 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-yaml
- id: check-toml
- id: check-json
- id: check-merge-conflict
- id: requirements-txt-fixer
- id: detect-private-key
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
# Ruff version.
rev: 'v0.0.241'
hooks:
- id: ruff
2 changes: 1 addition & 1 deletion env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,6 @@ dependencies:
- ruff
- ipykernel
- pydantic <= 2.0

- pip:
- torch-nl
51 changes: 25 additions & 26 deletions openqdc/datasets/ani.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import numpy as np
from os.path import join as p_join
from openqdc.utils.constants import MAX_ATOMIC_NUMBER

import numpy as np

from openqdc.datasets.base import BaseDataset, read_qc_archive_h5
from openqdc.utils.constants import MAX_ATOMIC_NUMBER
from openqdc.utils.io import get_local_cache


class ANI1(BaseDataset):
__name__ = 'ani1'

__name__ = "ani1"

# Energy in hartree, all zeros by default
atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32)
Expand All @@ -21,29 +22,27 @@ class ANI1(BaseDataset):
"ωB97x:6-31G(d) Energy",
]


def __init__(self) -> None:
    """Initialise the dataset by delegating entirely to the parent constructor."""
    super().__init__()

@property
def root(self):
    """Return the ``ani`` directory inside the local cache."""
    cache_dir = get_local_cache()
    return p_join(cache_dir, "ani")

@property
def preprocess_path(self):
    """Return ``<root>/preprocessed/<__name__>``, creating the directory if needed."""
    target_dir = p_join(self.root, "preprocessed", self.__name__)
    # exist_ok=True: repeated property accesses must not fail once the directory exists.
    os.makedirs(target_dir, exist_ok=True)
    return target_dir

def read_raw_entries(self):
    """Read the raw ``<root>/<__name__>.h5`` archive and return its samples.

    The archive is parsed by ``read_qc_archive_h5`` using this class's
    ``energy_target_names`` and ``force_target_names``; ``__name__`` is
    passed as the subset label.
    """
    archive_path = p_join(self.root, f"{self.__name__}.h5")
    return read_qc_archive_h5(
        archive_path,
        self.__name__,
        self.energy_target_names,
        self.force_target_names,
    )


class ANI1CCX(ANI1):
__name__ = 'ani1ccx'
__name__ = "ani1ccx"

# Energy in hartree, all zeros by default
atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32)
Expand All @@ -67,10 +66,10 @@ class ANI1CCX(ANI1):

def __init__(self) -> None:
    """Initialise the dataset by delegating entirely to the parent constructor."""
    super().__init__()


class ANI1X(ANI1):
__name__ = 'ani1x'
__name__ = "ani1x"

# Energy in hartree, all zeros by default
atomic_energies = np.zeros((MAX_ATOMIC_NUMBER,), dtype=np.float32)
Expand All @@ -94,10 +93,10 @@ class ANI1X(ANI1):
"MP2:cc-pVQZ Correlation Energy",
"MP2:cc-pVTZ Correlation Energy",
"wB97x:6-31G(d) Total Energy",
"wB97x:def2-TZVPP Total Energy"
"wB97x:def2-TZVPP Total Energy",
]

force_target_names = [
force_target_names = [
"wB97x:6-31G(d) Atomic Forces",
"wB97x:def2-TZVPP Atomic Forces",
]
Expand All @@ -111,21 +110,21 @@ def __init__(self) -> None:
super().__init__()


if __name__ == '__main__':
if __name__ == "__main__":
for data_class in [
ANI1,
# ANI1CCX,
# ANI1X
]:
ANI1,
# ANI1CCX,
# ANI1X
]:
data = data_class()
n = len(data)

for i in np.random.choice(n, 3, replace=False):
x = data[i]
print(x.name, x.subset, end=' ')
print(x.name, x.subset, end=" ")
for k in x:
if x[k] is not None:
print(k, x[k].shape, end=' ')
print(k, x[k].shape, end=" ")

print()
exit()
exit()
99 changes: 50 additions & 49 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
import os
import torch
import numpy as np
import pickle as pkl
from tqdm import tqdm
from os.path import join as p_join

import numpy as np
import torch
from sklearn.utils import Bunch
from openqdc.utils.io import get_local_cache, pull_locally, push_remote, load_hdf5_file, copy_exists
from tqdm import tqdm

from openqdc.utils.constants import NB_ATOMIC_FEATURES
from openqdc.utils.io import (
copy_exists,
get_local_cache,
load_hdf5_file,
pull_locally,
push_remote,
)
from openqdc.utils.molecule import atom_table
from openqdc.utils.constants import BOHR2ANG, MAX_ATOMIC_NUMBER, NB_ATOMIC_FEATURES


def extract_entry(df, i, subset, energy_target_names, force_target_names=None):
def extract_entry(df, i, subset, energy_target_names, force_target_names=None):
x = np.array([atom_table.GetAtomicNumber(s) for s in df["symbols"][i]])
xs = np.stack((x, np.zeros_like(x)), axis=-1)
positions= df["geometry"][i].reshape((-1, 3))
energies= np.array([df[k][i] for k in energy_target_names])
positions = df["geometry"][i].reshape((-1, 3))
energies = np.array([df[k][i] for k in energy_target_names])

res = dict(
name= np.array([df["name"][i]]),
subset= np.array([subset]),
energies= energies.reshape((1, -1)).astype(np.float32),
atomic_inputs = np.concatenate((xs, positions), axis=-1, dtype=np.float32),
n_atoms = np.array([x.shape[0]], dtype=np.int32),
name=np.array([df["name"][i]]),
subset=np.array([subset]),
energies=energies.reshape((1, -1)).astype(np.float32),
atomic_inputs=np.concatenate((xs, positions), axis=-1, dtype=np.float32),
n_atoms=np.array([x.shape[0]], dtype=np.int32),
)
if force_target_names is not None and len(force_target_names) > 0:
forces = np.zeros((positions.shape[0], 3, len(force_target_names)), dtype=np.float32)
Expand All @@ -47,8 +54,7 @@ def read_qc_archive_h5(raw_path, subset, energy_target_names, force_target_names
# print('\n'*3)
# exit()

samples = [extract_entry(data_t, i, subset, energy_target_names, force_target_names)
for i in tqdm(range(n))]
samples = [extract_entry(data_t, i, subset, energy_target_names, force_target_names) for i in tqdm(range(n))]
return samples


Expand All @@ -71,45 +77,44 @@ def __init__(self) -> None:
@property
def root(self):
    """Return the dataset directory: the local cache joined with ``__name__``."""
    cache_dir = get_local_cache()
    return p_join(cache_dir, self.__name__)

@property
def preprocess_path(self):
    """Return ``<root>/preprocessed``, creating the directory if needed."""
    target_dir = p_join(self.root, "preprocessed")
    # exist_ok=True: repeated property accesses must not fail once the directory exists.
    os.makedirs(target_dir, exist_ok=True)
    return target_dir

@property
def data_keys(self):
    """Return the names of the stored arrays, in ``data_types`` order.

    ``"forces"`` is dropped when ``__force_methods__`` is empty.
    """
    names = [*self.data_types.keys()]
    no_forces = len(self.__force_methods__) == 0
    if no_forces:
        names.remove("forces")
    return names

@property
def data_types(self):
    """Return the NumPy dtype used for each stored array, keyed by array name."""
    return dict(
        atomic_inputs=np.float32,
        position_idx_range=np.int32,
        energies=np.float32,
        forces=np.float32,
    )

@property
def data_shapes(self):
    """Return the reshape target for each stored array, keyed by array name.

    The leading ``-1`` leaves the first dimension to be inferred; the
    trailing dimensions come from ``NB_ATOMIC_FEATURES`` and the number of
    energy / force target names declared on the instance.
    """
    n_energy_targets = len(self.energy_target_names)
    n_force_targets = len(self.force_target_names)
    return dict(
        atomic_inputs=(-1, NB_ATOMIC_FEATURES),
        position_idx_range=(-1, 2),
        energies=(-1, n_energy_targets),
        forces=(-1, 3, n_force_targets),
    )

def read_raw_entries(self):
    """Read the raw dataset entries; concrete dataset classes must override this.

    Raises:
        NotImplementedError: always, naming the class that lacks an override.
    """
    raise NotImplementedError(f"{type(self).__name__} must implement read_raw_entries()")

def collate_list(self, list_entries):
# concatenate entries
res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0)
for key in list_entries[0]}
res = {key: np.concatenate([r[key] for r in list_entries if r is not None], axis=0) for key in list_entries[0]}

csum = np.cumsum(res.pop("n_atoms"))
x = np.zeros((csum.shape[0], 2), dtype=np.int32)
Expand All @@ -121,52 +126,48 @@ def save_preprocess(self, data_dict):
# save memmaps
for key in self.data_keys:
local_path = p_join(self.preprocess_path, f"{key}.mmap")
out = np.memmap(local_path,
mode="w+",
dtype=data_dict[key].dtype,
shape=data_dict[key].shape)
out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape)
out[:] = data_dict.pop(key)[:]
out.flush()
push_remote(local_path)

# save smiles and subset
for key in ["name", "subset"]:
local_path = p_join(self.preprocess_path, f"{key}.npz")
uniques, inv_indices = np.unique(data_dict[key], return_inverse=True)
with open(local_path, "wb") as f:
np.savez_compressed(f, uniques=uniques, inv_indices=inv_indices)
push_remote(local_path)
def read_preprocess(self):
    """Load the preprocessed arrays from ``preprocess_path`` into ``self.data``.

    Every key in ``data_keys`` is opened as a read-only memmap of
    ``<key>.mmap`` with the dtype from ``data_types`` and reshaped to the
    shape from ``data_shapes``. ``name`` and ``subset`` are read from their
    ``.npz`` archives. ``pull_locally`` is called on each file before it is
    read, and a summary line is printed for every loaded array.
    """
    self.data = {}
    for key in self.data_keys:
        filename = p_join(self.preprocess_path, f"{key}.mmap")
        pull_locally(filename)
        self.data[key] = np.memmap(
            filename,
            mode="r",
            dtype=self.data_types[key],
        ).reshape(self.data_shapes[key])

    for key in self.data:
        print(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")

    for key in ["name", "subset"]:
        filename = p_join(self.preprocess_path, f"{key}.npz")
        pull_locally(filename)
        # Read every array out inside the context manager so the archive's
        # file handle is closed, instead of keeping a lazy NpzFile open for
        # the lifetime of the dataset. The result is a plain dict that is
        # still indexable by array name.
        # NOTE(review): confirm no caller relies on NpzFile-only attributes
        # such as ``.files``.
        with np.load(filename) as archive:
            self.data[key] = {k: archive[k] for k in archive.files}
        for k in self.data[key]:
            print(f"Loaded {key}_{k} with shape {self.data[key][k].shape}, dtype {self.data[key][k].dtype}")

def is_preprocessed(self):
    """Return True when every expected preprocessed file passes ``copy_exists``.

    The expected files are ``<key>.mmap`` for each key in ``data_keys`` plus
    ``name.npz`` and ``subset.npz``, all under ``preprocess_path``.
    """
    checks = []
    for key in self.data_keys:
        checks.append(copy_exists(p_join(self.preprocess_path, f"{key}.mmap")))
    for x in ["name", "subset"]:
        checks.append(copy_exists(p_join(self.preprocess_path, f"{x}.npz")))
    # Every file is checked before combining, so no check is short-circuited.
    return all(checks)

def __len__(self):
    """Return the number of samples, i.e. the first dimension of the ``energies`` array."""
    energies = self.data["energies"]
    return energies.shape[0]

def __getitem__(self, idx: int):
p_start, p_end = self.data["position_idx_range"][idx]
Expand All @@ -190,5 +191,5 @@ def __getitem__(self, idx: int):
energies=energies,
name=name,
subset=subset,
forces=forces
forces=forces,
)
Loading

0 comments on commit 8d0cbc2

Please sign in to comment.