Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/release' into testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikhil Shenoy committed Apr 5, 2024
2 parents 7ffd0b1 + a71f4d7 commit d15e9cf
Show file tree
Hide file tree
Showing 19 changed files with 266 additions and 95 deletions.
4 changes: 4 additions & 0 deletions openqdc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
# The below lazy import logic is coming from openff-toolkit:
# https://github.com/openforcefield/openff-toolkit/blob/b52879569a0344878c40248ceb3bd0f90348076a/openff/toolkit/__init__.py#L44


def get_project_root():
    """Return the absolute path of the parent of the directory containing this file."""
    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


# Dictionary of objects to lazily import; maps the object's name to its module path
_lazy_imports_obj = {
"__version__": "openqdc._version",
Expand Down
57 changes: 26 additions & 31 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import os
import pickle as pkl
from copy import deepcopy
from functools import partial
from itertools import compress
from os.path import join as p_join
Expand Down Expand Up @@ -71,13 +70,13 @@ class BaseDataset(DatasetPropertyMixIn):
__energy_methods__ = []
__force_mask__ = []
__isolated_atom_energies__ = []
# Identity converters used until a unit change assigns a real conversion on the instance.
# They are wrapped in staticmethod because the accessors call them through the instance
# (self._fn_energy(x)); a plain function stored on the class would be bound to `self` and
# the one-argument callable would then fail with a TypeError.
_fn_energy = staticmethod(lambda x: x)
_fn_distance = staticmethod(lambda x: x)
_fn_forces = staticmethod(lambda x: x)

__energy_unit__ = "hartree"
__distance_unit__ = "ang"
__forces_unit__ = "hartree/ang"
__fn_energy__ = lambda x: x
__fn_distance__ = lambda x: x
__fn_forces__ = lambda x: x
__average_nb_atoms__ = None

def __init__(
Expand Down Expand Up @@ -123,6 +122,7 @@ def __init__(
solver_type can be one of ["linear", "ridge"]
"""
set_cache_dir(cache_dir)
# self._init_lambda_fn()
self.data = None
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
Expand All @@ -136,6 +136,11 @@ def __init__(
self.set_array_format(array_format)
self._post_init(overwrite_local_cache, energy_unit, distance_unit)

# Assign identity functions (x -> x) to this instance's _fn_energy, _fn_distance and
# _fn_forces attributes.
# NOTE(review): the only call to this method visible in this view (inside __init__) is
# commented out — confirm whether the method is still needed.
def _init_lambda_fn(self):
self._fn_energy = lambda x: x
self._fn_distance = lambda x: x
self._fn_forces = lambda x: x

def _post_init(
self,
overwrite_local_cache: bool = False,
Expand Down Expand Up @@ -272,7 +277,7 @@ def _set_units(self, en, ds):
self.set_distance_unit(ds)
if self.__force_methods__:
self.__forces_unit__ = self.energy_unit + "/" + self.distance_unit
self.__class__.__fn_forces__ = get_conversion(old_en + "/" + old_ds, self.__forces_unit__)
self._fn_forces = get_conversion(old_en + "/" + old_ds, self.__forces_unit__)

def _set_isolated_atom_energies(self):
if self.__energy_methods__ is None:
Expand All @@ -281,29 +286,29 @@ def _set_isolated_atom_energies(self):
self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

def convert_energy(self, x):
return self.__class__.__fn_energy__(x)
return self._fn_energy(x)

def convert_distance(self, x):
return self.__class__.__fn_distance__(x)
return self._fn_distance(x)

def convert_forces(self, x):
return self.__class__.__fn_forces__(x)
return self._fn_forces(x)

def set_energy_unit(self, value: str):
"""
Set a new energy unit for the dataset.
"""
old_unit = self.energy_unit
self.__energy_unit__ = value
self.__class__.__fn_energy__ = get_conversion(old_unit, value)
self._fn_energy = get_conversion(old_unit, value)

def set_distance_unit(self, value: str):
"""
Set a new distance unit for the dataset.
"""
old_unit = self.distance_unit
self.__distance_unit__ = value
self.__class__.__fn_distance__ = get_conversion(old_unit, value)
self._fn_distance = get_conversion(old_unit, value)

def set_array_format(self, format: str):
assert format in ["numpy", "torch", "jax"], f"Format {format} not supported."
Expand Down Expand Up @@ -535,39 +540,29 @@ def get_statistics(self, return_none: bool = True):
Whether to return None if the statistics for the forces are not available, by default True
Otherwise, the statistics for the forces are set to 0.0
"""
stats = deepcopy(self.statistics.get_results())
if len(stats) == 0:
selected_stats = self.statistics.get_results()
if len(selected_stats) == 0:
raise StatisticsNotAvailableError(self.__name__)
selected_stats = stats
if not return_none:
selected_stats.update(
{
"ForcesCalculatorStats": {
"mean": np.array([0.0]),
"std": np.array([0.0]),
"components": {
"mean": np.array([[0.0], [0.0], [0.0]]),
"std": np.array([[0.0], [0.0], [0.0]]),
"rms": np.array([[0.0], [0.0], [0.0]]),
},
"component_mean": np.array([[0.0], [0.0], [0.0]]),
"component_std": np.array([[0.0], [0.0], [0.0]]),
"component_rms": np.array([[0.0], [0.0], [0.0]]),
}
}
)
# cycle through dict to convert units
for key in selected_stats:
if key.lower() == str(ForcesCalculatorStats):
for key2 in selected_stats[key]:
if key2 != "components":
selected_stats[key][key2] = self.convert_forces(selected_stats[key][key2])
else:
for key2 in selected_stats[key]["components"]:
selected_stats[key]["components"][key2] = self.convert_forces(
selected_stats[key]["components"][key2]
)
for key, result in selected_stats.items():
if isinstance(result, ForcesCalculatorStats):
result.transform(self.convert_forces)
else:
for key2 in selected_stats[key]:
selected_stats[key][key2] = self.convert_energy(selected_stats[key][key2])
return selected_stats
result.transform(self.convert_energy)
result.transform(self._convert_array)
return {k: result.to_dict() for k, result in selected_stats.items()}

# String form of the dataset: the value of its __name__ attribute, formatted as text.
def __str__(self):
return f"{self.__name__}"
Expand Down
5 changes: 5 additions & 0 deletions openqdc/datasets/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ def __init__(
self.__energy_unit__ = energy_unit
self.__distance_unit__ = distance_unit
self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory]
self.energy_target_names = ["xyz"]
self.regressor_kwargs = regressor_kwargs
self.transform = transform
self._read_and_preprocess()
if "forces" in self.data:
self.__force_mask__ = [True]
self.__class__.__force_methods__ = [level_of_theory]
self.force_target_names = ["xyz"]
self.set_array_format(array_format)
self._post_init(True, energy_unit, distance_unit)

Expand Down
80 changes: 58 additions & 22 deletions openqdc/datasets/potential/dummy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pickle as pkl
from os.path import join as p_join

import numpy as np
from loguru import logger

from openqdc.datasets.base import BaseDataset
from openqdc.methods import PotentialMethod
from openqdc.utils.constants import NOT_DEFINED


class Dummy(BaseDataset):
Expand All @@ -11,8 +14,8 @@ class Dummy(BaseDataset):
"""

__name__ = "dummy"
__energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6]
__force_mask__ = [False, True]
__energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6, PotentialMethod.GFN2_XTB]
__force_mask__ = [False, True, True]
__energy_unit__ = "kcal/mol"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"
Expand All @@ -23,25 +26,6 @@ class Dummy(BaseDataset):
__isolated_atom_energies__ = []
__average_n_atoms__ = None

@property
def _stats(self):
return {
"formation": {
"energy": {
"mean": np.array([[-12.94348027, -9.83037297]]),
"std": np.array([[4.39971409, 3.3574188]]),
},
"forces": NOT_DEFINED,
},
"total": {
"energy": {
"mean": np.array([[-89.44242, -1740.5336]]),
"std": np.array([[29.599571, 791.48663]]),
},
"forces": NOT_DEFINED,
},
}

# Run self.setup_dummy() first, then delegate to the parent class's _post_init with the
# same three arguments and return whatever it returns.
def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None:
self.setup_dummy()
return super()._post_init(overwrite_local_cache, energy_unit, distance_unit)
Expand Down Expand Up @@ -90,3 +74,55 @@ def read_raw_entries(self):

# Report a fixed length of 9999, independent of any stored data.
def __len__(self):
return 9999


class PredefinedDataset(BaseDataset):
    """Dataset that reads already-preprocessed files from the project's ``tests/files`` tree.

    ``is_preprocessed`` always reports ``True`` and ``read_raw_entries`` does nothing, so the
    only data source is the content of ``preprocess_path``.
    """

    __name__ = "predefineddataset"
    __energy_methods__ = [PotentialMethod.WB97M_D3BJ_DEF2_TZVPPD]  # wb97m-d3bj/def2-tzvppd
    __force_mask__ = [True]
    __energy_unit__ = "hartree"
    __distance_unit__ = "bohr"
    __forces_unit__ = "hartree/bohr"
    # Both names below refer to the very same list object as __energy_methods__.
    force_target_names = __energy_methods__
    energy_target_names = __energy_methods__

    @property
    def preprocess_path(self):
        """Return ``<project root>/tests/files/<__name__>/preprocessed``."""
        # Kept as a function-level import, as in the original code
        # (presumably to avoid a circular import with the package root — TODO confirm).
        from openqdc import get_project_root

        return p_join(get_project_root(), "tests", "files", self.__name__, "preprocessed")

    def is_preprocessed(self):
        """Always report the dataset as preprocessed."""
        return True

    def read_raw_entries(self):
        """Do nothing: this dataset has no raw entries to read."""
        pass

    def read_preprocess(self, overwrite_local_cache=False):
        """Load the memory-mapped arrays and pickled properties into ``self.data``.

        Args:
            overwrite_local_cache: Accepted but not used by this implementation.
        """
        logger.info("Reading preprocessed data.")
        logger.info(
            f"Dataset {self.__name__} with the following units:\n"
            f"Energy: {self.energy_unit},\n"
            f"Distance: {self.distance_unit},\n"
            f"Forces: {self.force_unit if self.force_methods else 'None'}"
        )

        base_dir = self.preprocess_path
        self.data = {}
        for key in self.data_keys:
            shape = self.data_shapes[key]
            dtype = self.data_types[key]
            # Diagnostic output goes through the logger instead of a bare print().
            logger.debug(f"{key} {shape} {dtype}")
            mmap_file = p_join(base_dir, f"{key}.mmap")
            # Open read-only and view the flat buffer with the declared shape.
            self.data[key] = np.memmap(mmap_file, mode="r", dtype=dtype).reshape(*shape)

        props_file = p_join(base_dir, "props.pkl")
        # NOTE: unpickling can execute arbitrary code; acceptable here only because the file
        # is read from the project's own tests/files directory.
        with open(props_file, "rb") as f:
            props = pkl.load(f)

        for key in ["name", "subset", "n_atoms"]:
            x = props.pop(key)
            # NOTE(review): a length-2 entry is expanded as x[0][x[1]] — presumably a
            # (values, indices) pair; confirm against the code that writes props.pkl.
            if len(x) == 2:
                self.data[key] = x[0][x[1]]
            else:
                self.data[key] = x

        for key, value in self.data.items():
            logger.info(f"Loaded {key} with shape {value.shape}, dtype {value.dtype}")
Loading

0 comments on commit d15e9cf

Please sign in to comment.