Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Interaction and Better Testing #71

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
31beb71
refactor interaction and initial testing
Mar 26, 2024
dccf676
minor changes
Mar 26, 2024
2ab64aa
dummy modification
Mar 26, 2024
189ab90
undo changes in interaction dataset, and minor change in shape
Mar 29, 2024
282dc91
changed super class to BaseInteractionDataset
Apr 2, 2024
701ef1e
Merge branch 'release' into testing
Apr 3, 2024
afea053
further simplified and rebase
Apr 3, 2024
ebc2adf
fixes
Apr 5, 2024
7ffd0b1
Merge remote-tracking branch 'origin/release' into testing
Apr 5, 2024
d15e9cf
Merge remote-tracking branch 'origin/release' into testing
Apr 5, 2024
ed8e264
Updated metcalf
Apr 5, 2024
18bc79c
bug fix and simplifying interaction dataset
Apr 6, 2024
2a6e3ef
Updated tests for interaction datasets
Apr 6, 2024
7493273
removed stale stats in dummy interaction
Apr 6, 2024
ed73e7d
changes based on comments
Apr 6, 2024
0359022
Clean metcalf
FNTwin Apr 6, 2024
33fa342
Simplification
FNTwin Apr 6, 2024
cd486a8
cleaned des
FNTwin Apr 6, 2024
80d7371
Simplified des dataset
FNTwin Apr 6, 2024
f3d205c
removed redundant dataset files
FNTwin Apr 6, 2024
da4fece
DES inheritance
FNTwin Apr 6, 2024
71ff741
Removed des and improved des naming
FNTwin Apr 6, 2024
f6e12e1
DES fixes
FNTwin Apr 6, 2024
3328a65
Removed comments
FNTwin Apr 6, 2024
8b28d59
X40 and L70
FNTwin Apr 6, 2024
8595fd8
Safe opening
FNTwin Apr 6, 2024
ca1b4af
Moved X40 in L7 and removed x40.py
FNTwin Apr 6, 2024
4bec82d
Moved Yaml utils to _utils.py, L7 + X40 interface
FNTwin Apr 7, 2024
a5ced0a
Merge testing + Add imports
FNTwin Apr 8, 2024
a21963e
Merge pull request #79 from OpenDrugDiscovery/interaction_impr
shenoynikhil Apr 8, 2024
3303f95
better convert function and n_body_first to ptr
Apr 12, 2024
6f033cf
Updated splinter reading from -1 to nan
Apr 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,14 @@ def data_keys(self):
keys.remove("forces")
return keys

@property
def pkl_data_keys(self):
    """Names of the (required) metadata entries persisted in props.pkl.

    Derived from :attr:`pkl_data_types` so the two properties can never
    drift out of sync.
    """
    return list(self.pkl_data_types.keys())

@property
def pkl_data_types(self):
    """Expected dtype for each metadata entry stored in props.pkl."""
    return dict(name=str, subset=str, n_atoms=np.int32)

@property
def data_types(self):
return {
Expand All @@ -257,8 +265,8 @@ def data_shapes(self):
return {
"atomic_inputs": (-1, NB_ATOMIC_FEATURES),
"position_idx_range": (-1, 2),
"energies": (-1, len(self.energy_target_names)),
"forces": (-1, 3, len(self.force_target_names)),
"energies": (-1, len(self.energy_methods)),
"forces": (-1, 3, len(self.force_methods)),
}

def _set_units(self, en, ds):
Expand Down Expand Up @@ -332,8 +340,14 @@ def save_preprocess(self, data_dict):

# save smiles and subset
local_path = p_join(self.preprocess_path, "props.pkl")
for key in ["name", "subset"]:
data_dict[key] = np.unique(data_dict[key], return_inverse=True)

# assert that (required) pkl keys are present in data_dict
assert all([key in data_dict.keys() for key in self.pkl_data_keys])

# store unique and inverse indices for str-based pkl keys
for key in self.pkl_data_keys:
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
if self.pkl_data_types[key] == str:
data_dict[key] = np.unique(data_dict[key], return_inverse=True)

with open(local_path, "wb") as f:
pkl.dump(data_dict, f)
Expand Down Expand Up @@ -369,7 +383,10 @@ def read_preprocess(self, overwrite_local_cache=False):
pull_locally(filename, overwrite=overwrite_local_cache)
with open(filename, "rb") as f:
tmp = pkl.load(f)
for key in ["name", "subset", "n_atoms"]:
all_pkl_keys = set(tmp.keys()) - set(self.data_keys)
# assert required pkl_keys are present in all_pkl_keys
assert all([key in all_pkl_keys for key in self.pkl_data_keys])
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
for key in all_pkl_keys:
x = tmp.pop(key)
if len(x) == 2:
self.data[key] = x[0][x[1]]
Expand Down
79 changes: 44 additions & 35 deletions openqdc/datasets/interaction/L7.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from typing import Dict, List
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional

import numpy as np
import yaml
Expand All @@ -10,42 +12,49 @@
from openqdc.utils.constants import ATOM_TABLE


class DataItemYAMLObj:
def __init__(self, name, shortname, geometry, reference_value, setup, group, tags):
self.name = name
self.shortname = shortname
self.geometry = geometry
self.reference_value = reference_value
self.setup = setup
self.group = group
self.tags = tags


class DataSetYAMLObj:
def __init__(self, name, references, text, method_energy, groups_by, groups, global_setup, method_geometry=None):
self.name = name
self.references = references
self.text = text
self.method_energy = method_energy
self.method_geometry = method_geometry
self.groups_by = groups_by
self.groups = groups
self.global_setup = global_setup


def data_item_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
return DataItemYAMLObj(**loader.construct_mapping(node))
@dataclass
class DataSet:
    """Root object deserialized from an L7/X40 YAML protocol file."""

    # Deserialized dataset metadata. Annotated as a forward reference because
    # downstream code uses attribute access (`description.global_setup`), so
    # this is a DataSetDescription instance, not a plain dict.
    description: "DataSetDescription"
    # One DataItemYAMLObj per system (items are accessed as `item.shortname`,
    # `item.geometry`, ... downstream).
    items: List["DataItemYAMLObj"]
    # Mapping of alternative reference method name -> per-item energy values,
    # indexed by item position.
    alternative_reference: Dict


def dataset_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode):
return DataSetYAMLObj(**loader.construct_mapping(node))
@dataclass
class DataItemYAMLObj:
    """One entry of the YAML protocol file: a single dimer/system record."""

    # Full system name.
    name: str
    # Short identifier; used as the per-sample name when reading raw entries.
    shortname: str
    # Geometry reference of the form "<prefix>:<file>" — the part after the
    # ":" is the .xyz geometry filename.
    geometry: str
    # Reference interaction energy for this system.
    reference_value: float
    # Per-item setup mapping; setup["molecule_a"]["selection"] encodes the
    # atom range of the first monomer (e.g. "1-42").
    setup: Dict
    # Group label; used as the dataset subset.
    group: str
    # Tag annotations carried over from the YAML entry.
    tags: str


@dataclass
class DataSetDescription:
    """Top-level metadata block of an L7/X40 YAML protocol dataset."""

    # NOTE(review): annotated Dict, but a dataset name is typically a plain
    # string — confirm against the actual YAML payload.
    name: Dict
    # Bibliographic references for the dataset.
    references: str
    # Free-text description.
    text: str
    # Criterion the groups are formed by.
    groups_by: str
    # Group labels present in the dataset.
    groups: List[str]
    # Global setup mapping; global_setup["molecule_a"]["charge"] and
    # global_setup["molecule_b"]["charge"] provide the monomer charges.
    global_setup: Dict
    # Method used for the reference energies.
    method_energy: str
    # Method used for the geometries; may be absent in the YAML.
    method_geometry: Optional[str] = None


def get_loader():
    """Return a PyYAML loader able to parse the Ruby-serialized protocol files.

    Registers constructors for the three ``!ruby/object:ProtocolDataset::*``
    tags so they deserialize into the corresponding dataclasses. The stale
    registrations pointing at the removed ``data_item_constructor`` /
    ``dataset_constructor`` helpers are gone — they referenced undefined
    names and would raise ``NameError``.

    NOTE(review): this mutates the shared ``yaml.SafeLoader`` class globally
    rather than subclassing it — confirm no other code relies on a pristine
    SafeLoader.
    """

    def constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode, cls):
        # Build the target dataclass directly from the YAML mapping node.
        return cls(**loader.construct_mapping(node))

    loader = yaml.SafeLoader
    loader.add_constructor("!ruby/object:ProtocolDataset::DataSet", partial(constructor, cls=DataSet))
    loader.add_constructor("!ruby/object:ProtocolDataset::DataSetItem", partial(constructor, cls=DataItemYAMLObj))
    loader.add_constructor(
        "!ruby/object:ProtocolDataset::DataSetDescription", partial(constructor, cls=DataSetDescription)
    )
    return loader


Expand All @@ -62,7 +71,7 @@ class L7(BaseInteractionDataset):
http://cuby4.molecular.cz/dataset_l7.html
"""

__name__ = "L7"
__name__ = "l7"
__energy_unit__ = "kcal/mol"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"
Expand All @@ -87,10 +96,10 @@ def read_raw_entries(self) -> List[Dict]:
yaml_file = open(yaml_fpath, "r")
data = []
data_dict = yaml.load(yaml_file, Loader=get_loader())
charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"])
charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"])
charge0 = int(data_dict.description.global_setup["molecule_a"]["charge"])
charge1 = int(data_dict.description.global_setup["molecule_b"]["charge"])

for idx, item in enumerate(data_dict["items"]):
for idx, item in enumerate(data_dict.items):
energies = []
name = np.array([item.shortname])
fname = item.geometry.split(":")[1]
Expand All @@ -101,7 +110,7 @@ def read_raw_entries(self) -> List[Dict]:
n_atoms = np.array([int(lines[0][0])], dtype=np.int32)
n_atoms_first = np.array([int(item.setup["molecule_a"]["selection"].split("-")[1])], dtype=np.int32)
subset = np.array([item.group])
energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())]
energies += [float(val[idx]) for val in list(data_dict.alternative_reference.values())]
energies = np.array([energies], dtype=np.float32)
pos = np.array(lines[1:])[:, 1:].astype(np.float32)
elems = np.array(lines[1:])[:, 0]
Expand Down
10 changes: 5 additions & 5 deletions openqdc/datasets/interaction/X40.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class X40(BaseInteractionDataset):
http://cuby4.molecular.cz/dataset_x40.html
"""

__name__ = "X40"
__name__ = "x40"
__energy_unit__ = "hartree"
__distance_unit__ = "ang"
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
__forces_unit__ = "hartree/ang"
Expand All @@ -48,10 +48,10 @@ def read_raw_entries(self) -> List[Dict]:
yaml_file = open(yaml_fpath, "r")
data = []
data_dict = yaml.load(yaml_file, Loader=get_loader())
charge0 = int(data_dict["description"].global_setup["molecule_a"]["charge"])
charge1 = int(data_dict["description"].global_setup["molecule_b"]["charge"])
charge0 = int(data_dict.description.global_setup["molecule_a"]["charge"])
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
charge1 = int(data_dict.description.global_setup["molecule_b"]["charge"])

for idx, item in enumerate(data_dict["items"]):
for idx, item in enumerate(data_dict.items):
energies = []
name = np.array([item.shortname])
energies.append(float(item.reference_value))
Expand All @@ -62,7 +62,7 @@ def read_raw_entries(self) -> List[Dict]:
n_atoms_first = setup[0].split("-")[1]
n_atoms_first = np.array([int(n_atoms_first)], dtype=np.int32)
subset = np.array([item.group])
energies += [float(val[idx]) for val in list(data_dict["alternative_reference"].values())]
energies += [float(val[idx]) for val in list(data_dict.alternative_reference.values())]
energies = np.array([energies], dtype=np.float32)
pos = np.array(lines[1:])[:, 1:].astype(np.float32)
elems = np.array(lines[1:])[:, 0]
Expand Down
95 changes: 11 additions & 84 deletions openqdc/datasets/interaction/base.py
Copy link
Collaborator

@FNTwin FNTwin Apr 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, trying to load any interaction dataset raises an error because the
`if not self.is_preprocessed()` check fails due to the naming.

In the bucket the files were written as L7 and X40 (upper case). We should always sanitize names to lower case. Since we need to postprocess the datasets again to add the new keys, this will fix itself.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, need to push new changes.

Original file line number Diff line number Diff line change
@@ -1,52 +1,26 @@
import os
import pickle as pkl
from os.path import join as p_join
from typing import Dict, List, Optional
from typing import Optional

import numpy as np
from ase.io.extxyz import write_extxyz
from loguru import logger
from sklearn.utils import Bunch

from openqdc.datasets.base import BaseDataset
from openqdc.utils.constants import MAX_CHARGE, NB_ATOMIC_FEATURES
from openqdc.utils.io import pull_locally, push_remote, to_atoms
from openqdc.utils.constants import MAX_CHARGE
from openqdc.utils.io import to_atoms


class BaseInteractionDataset(BaseDataset):
__energy_type__ = []

def collate_list(self, list_entries: List[Dict]):
    """Merge per-sample entry dicts into one flat dict of concatenated arrays.

    ``None`` entries are ignored; dict-valued fields (judged from the first
    entry) are dropped. Adds a ``position_idx_range`` (n, 2) int32 array with
    the [start, end) atom-index span of each sample, derived from ``n_atoms``.
    """
    valid_entries = [entry for entry in list_entries if entry is not None]
    merged = {}
    for key, first_value in list_entries[0].items():
        if isinstance(first_value, dict):
            continue
        merged[key] = np.concatenate([entry[key] for entry in valid_entries], axis=0)

    ends = np.cumsum(merged.get("n_atoms"))
    ranges = np.zeros((ends.shape[0], 2), dtype=np.int32)
    ranges[1:, 0] = ends[:-1]
    ranges[:, 1] = ends
    merged["position_idx_range"] = ranges

    return merged

@property
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
def data_shapes(self):
return {
"atomic_inputs": (-1, NB_ATOMIC_FEATURES),
"position_idx_range": (-1, 2),
"energies": (-1, len(self.__energy_methods__)),
"forces": (-1, 3, len(self.force_target_names)),
}

@property
def data_types(self):
def pkl_data_types(self):
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
return {
"atomic_inputs": np.float32,
"position_idx_range": np.int32,
"energies": np.float32,
"forces": np.float32,
"name": str,
"subset": str,
"n_atoms": np.int32,
"n_atoms_first": np.int32,
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
}

def __getitem__(self, idx: int):
Expand All @@ -68,13 +42,16 @@ def __getitem__(self, idx: int):
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))

e0 = self._convert_array(np.array(self.__isolated_atom_energies__[..., z, c + shift].T, dtype=np.float32))
formation_energies = energies - e0.sum(axis=0)
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved

bunch = Bunch(
positions=positions,
atomic_numbers=z,
charges=c,
e0=e0,
energies=energies,
formation_energies=formation_energies,
per_atom_formation_energies=formation_energies / len(z),
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
name=name,
subset=subset,
forces=forces,
Expand All @@ -86,56 +63,6 @@ def __getitem__(self, idx: int):

return bunch

def save_preprocess(self, data_dict):
# save memmaps
logger.info("Preprocessing data and saving it to cache.")
for key in self.data_keys:
local_path = p_join(self.preprocess_path, f"{key}.mmap")
out = np.memmap(local_path, mode="w+", dtype=data_dict[key].dtype, shape=data_dict[key].shape)
out[:] = data_dict.pop(key)[:]
out.flush()
push_remote(local_path, overwrite=True)

# save all other keys in props.pkl
local_path = p_join(self.preprocess_path, "props.pkl")
for key in data_dict:
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
if key not in self.data_keys:
x = data_dict[key]
x[x == None] = -1 # noqa
data_dict[key] = np.unique(x, return_inverse=True)

with open(local_path, "wb") as f:
pkl.dump(data_dict, f)
push_remote(local_path, overwrite=True)

def read_preprocess(self, overwrite_local_cache=False):
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
logger.info("Reading preprocessed data.")
logger.info(
f"Dataset {self.__name__} with the following units:\n\
Energy: {self.energy_unit},\n\
Distance: {self.distance_unit},\n\
Forces: {self.force_unit if self.__force_methods__ else 'None'}"
)
self.data = {}
for key in self.data_keys:
filename = p_join(self.preprocess_path, f"{key}.mmap")
pull_locally(filename, overwrite=overwrite_local_cache)
self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(self.data_shapes[key])

filename = p_join(self.preprocess_path, "props.pkl")
pull_locally(filename, overwrite=overwrite_local_cache)
with open(filename, "rb") as f:
tmp = pkl.load(f)
for key in set(tmp.keys()) - set(self.data_keys):
x = tmp.pop(key)
if len(x) == 2:
self.data[key] = x[0][x[1]]
else:
self.data[key] = x

for key in self.data:
logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")

def get_ase_atoms(self, idx: int):
entry = self[idx]
at = to_atoms(entry["positions"], entry["atomic_numbers"])
shenoynikhil marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading
Loading