Bug Fixes and Component-wise Force Simplification #78

Merged 7 commits on Apr 5, 2024
4 changes: 4 additions & 0 deletions openqdc/__init__.py
@@ -5,7 +5,11 @@
# The below lazy import logic is coming from openff-toolkit:
# https://github.com/openforcefield/openff-toolkit/blob/b52879569a0344878c40248ceb3bd0f90348076a/openff/toolkit/__init__.py#L44


# Dictionary of objects to lazily import; maps the object's name to its module path
def get_project_root():
return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


_lazy_imports_obj = {
"__version__": "openqdc._version",
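The new `get_project_root` helper resolves the repository root from the package location. A minimal usage sketch, mirroring how the `PredefinedDataset` test fixture further down in this PR builds its path to the preprocessed test files:

```python
from os.path import join as p_join

from openqdc import get_project_root

# Build a path relative to the repository root instead of the current
# working directory, e.g. to locate test fixtures in a source checkout.
preprocessed = p_join(get_project_root(), "tests", "files", "predefineddataset", "preprocessed")
print(preprocessed)
```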
57 changes: 26 additions & 31 deletions openqdc/datasets/base.py
@@ -2,7 +2,6 @@

import os
import pickle as pkl
from copy import deepcopy
from functools import partial
from itertools import compress
from os.path import join as p_join
@@ -71,13 +70,13 @@ class BaseDataset(DatasetPropertyMixIn):
__energy_methods__ = []
__force_mask__ = []
__isolated_atom_energies__ = []
_fn_energy = lambda x: x
_fn_distance = lambda x: x
_fn_forces = lambda x: x

__energy_unit__ = "hartree"
__distance_unit__ = "ang"
__forces_unit__ = "hartree/ang"
__fn_energy__ = lambda x: x
__fn_distance__ = lambda x: x
__fn_forces__ = lambda x: x
__average_nb_atoms__ = None

def __init__(
@@ -123,6 +122,7 @@ def __init__(
solver_type can be one of ["linear", "ridge"]
"""
set_cache_dir(cache_dir)
# self._init_lambda_fn()
self.data = None
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
@@ -136,6 +136,11 @@ def __init__(
self.set_array_format(array_format)
self._post_init(overwrite_local_cache, energy_unit, distance_unit)

def _init_lambda_fn(self):
self._fn_energy = lambda x: x
self._fn_distance = lambda x: x
self._fn_forces = lambda x: x

def _post_init(
self,
overwrite_local_cache: bool = False,
@@ -264,7 +269,7 @@ def _set_units(self, en, ds):
self.set_distance_unit(ds)
if self.__force_methods__:
self.__forces_unit__ = self.energy_unit + "/" + self.distance_unit
self.__class__.__fn_forces__ = get_conversion(old_en + "/" + old_ds, self.__forces_unit__)
self._fn_forces = get_conversion(old_en + "/" + old_ds, self.__forces_unit__)

def _set_isolated_atom_energies(self):
if self.__energy_methods__ is None:
@@ -273,29 +278,29 @@ def _set_isolated_atom_energies(self):
self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

def convert_energy(self, x):
return self.__class__.__fn_energy__(x)
return self._fn_energy(x)

def convert_distance(self, x):
return self.__class__.__fn_distance__(x)
return self._fn_distance(x)

def convert_forces(self, x):
return self.__class__.__fn_forces__(x)
return self._fn_forces(x)

def set_energy_unit(self, value: str):
"""
Set a new energy unit for the dataset.
"""
old_unit = self.energy_unit
self.__energy_unit__ = value
self.__class__.__fn_energy__ = get_conversion(old_unit, value)
self._fn_energy = get_conversion(old_unit, value)

def set_distance_unit(self, value: str):
"""
Set a new distance unit for the dataset.
"""
old_unit = self.distance_unit
self.__distance_unit__ = value
self.__class__.__fn_distance__ = get_conversion(old_unit, value)
self._fn_distance = get_conversion(old_unit, value)

def set_array_format(self, format: str):
assert format in ["numpy", "torch", "jax"], f"Format {format} not supported."
@@ -518,39 +523,29 @@ def get_statistics(self, return_none: bool = True):
Whether to return None if the statistics for the forces are not available, by default True
Otherwise, the statistics for the forces are set to 0.0
"""
stats = deepcopy(self.statistics.get_results())
if len(stats) == 0:
selected_stats = self.statistics.get_results()
if len(selected_stats) == 0:
raise StatisticsNotAvailableError(self.__name__)
selected_stats = stats
if not return_none:
selected_stats.update(
{
"ForcesCalculatorStats": {
"mean": np.array([0.0]),
"std": np.array([0.0]),
"components": {
"mean": np.array([[0.0], [0.0], [0.0]]),
"std": np.array([[0.0], [0.0], [0.0]]),
"rms": np.array([[0.0], [0.0], [0.0]]),
},
"component_mean": np.array([[0.0], [0.0], [0.0]]),
"component_std": np.array([[0.0], [0.0], [0.0]]),
"component_rms": np.array([[0.0], [0.0], [0.0]]),
}
}
)
# cycle through dict to convert units
for key in selected_stats:
if key.lower() == str(ForcesCalculatorStats):
for key2 in selected_stats[key]:
if key2 != "components":
selected_stats[key][key2] = self.convert_forces(selected_stats[key][key2])
else:
for key2 in selected_stats[key]["components"]:
selected_stats[key]["components"][key2] = self.convert_forces(
selected_stats[key]["components"][key2]
)
for key, result in selected_stats.items():
if isinstance(result, ForcesCalculatorStats):
result.transform(self.convert_forces)
else:
for key2 in selected_stats[key]:
selected_stats[key][key2] = self.convert_energy(selected_stats[key][key2])
return selected_stats
result.transform(self.convert_energy)
result.transform(self._convert_array)
return {k: result.to_dict() for k, result in selected_stats.items()}

def __str__(self):
return f"{self.__name__}"
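The switch from the class-level `__fn_energy__`/`__fn_distance__`/`__fn_forces__` dunders to per-instance `_fn_*` attributes is one of the bug fixes in this file: assigning the converter on `self.__class__` leaked a unit change from one dataset instance into every other live instance. A minimal standalone sketch of the two patterns (not openQDC code, just the idea):

```python
class ClassLevelConverter:
    _fn = staticmethod(lambda x: x)  # shared by every instance

    def set_scale(self, k):
        # Assigning on the class mutates shared state: all instances change.
        type(self)._fn = staticmethod(lambda x, k=k: k * x)


class InstanceLevelConverter:
    def __init__(self):
        self._fn = lambda x: x  # owned by this instance only

    def set_scale(self, k):
        self._fn = lambda x, k=k: k * x


a, b = ClassLevelConverter(), ClassLevelConverter()
a.set_scale(2.0)
print(b._fn(1.0))  # 2.0 -> b was silently affected by a's unit change

c, d = InstanceLevelConverter(), InstanceLevelConverter()
c.set_scale(2.0)
print(d._fn(1.0))  # 1.0 -> d keeps its own converter
```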
7 changes: 3 additions & 4 deletions openqdc/datasets/interaction/base.py
@@ -63,12 +63,11 @@ def __getitem__(self, idx: int):
subset = self.data["subset"][idx]
n_atoms_first = self.data["n_atoms_first"][idx]

forces = None
if "forces" in self.data:
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end]), dtype=np.float32)
else:
forces = None
forces = self._convert_array(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))

e0 = self._convert_array(self.__isolated_atom_energies__[..., z, c + shift].T, dtype=np.float32)
e0 = self._convert_array(np.array(self.__isolated_atom_energies__[..., z, c + shift].T, dtype=np.float32))

bunch = Bunch(
positions=positions,
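The `__getitem__` change moves the `dtype` inside `np.array(...)` so the cast happens before `_convert_array` dispatches to the configured array format. A minimal sketch of that pattern, where `to_framework` is a hypothetical stand-in for `_convert_array` (the helper itself is not part of this diff):

```python
import numpy as np


def to_framework(arr, framework="numpy"):
    # Hypothetical stand-in for BaseDataset._convert_array: it only dispatches
    # to the requested array library; the caller fixes the dtype beforehand.
    if framework == "torch":
        import torch

        return torch.from_numpy(arr)
    return arr


forces = to_framework(np.array([[0.1, 0.2, 0.3]], dtype=np.float32))
print(forces.dtype)  # float32: the cast happened before any conversion
```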
5 changes: 5 additions & 0 deletions openqdc/datasets/io.py
@@ -50,9 +50,14 @@ def __init__(
self.__energy_unit__ = energy_unit
self.__distance_unit__ = distance_unit
self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory]
self.energy_target_names = ["xyz"]
self.regressor_kwargs = regressor_kwargs
self.transform = transform
self._read_and_preprocess()
if "forces" in self.data:
self.__force_mask__ = [True]
self.__class__.__force_methods__ = [level_of_theory]
self.force_target_names = ["xyz"]
self.set_array_format(array_format)
self._post_init(True, energy_unit, distance_unit)

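With this change, a file-backed dataset picks up forces automatically when the parsed entries contain a `forces` key. A hedged usage sketch; the `XYZDataset` class name and its constructor arguments are assumptions that are not shown in this diff:

```python
from openqdc.datasets.io import XYZDataset  # assumed entry point, not shown in this diff

# Hypothetical call: the exact keyword names may differ in the library.
ds = XYZDataset(
    path=["water_with_forces.xyz"],
    energy_unit="hartree",
    distance_unit="ang",
)
# If the parsed records contained a "forces" key, the dataset now also exposes
# __force_mask__ = [True] and force_target_names = ["xyz"].
```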
80 changes: 58 additions & 22 deletions openqdc/datasets/potential/dummy.py
@@ -1,8 +1,11 @@
import pickle as pkl
from os.path import join as p_join

import numpy as np
from loguru import logger

from openqdc.datasets.base import BaseDataset
from openqdc.methods import PotentialMethod
from openqdc.utils.constants import NOT_DEFINED


class Dummy(BaseDataset):
@@ -11,8 +14,8 @@ class Dummy(BaseDataset):
"""

__name__ = "dummy"
__energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6]
__force_mask__ = [False, True]
__energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6, PotentialMethod.GFN2_XTB]
__force_mask__ = [False, True, True]
__energy_unit__ = "kcal/mol"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"
@@ -23,25 +26,6 @@ class Dummy(BaseDataset):
__isolated_atom_energies__ = []
__average_n_atoms__ = None

@property
def _stats(self):
return {
"formation": {
"energy": {
"mean": np.array([[-12.94348027, -9.83037297]]),
"std": np.array([[4.39971409, 3.3574188]]),
},
"forces": NOT_DEFINED,
},
"total": {
"energy": {
"mean": np.array([[-89.44242, -1740.5336]]),
"std": np.array([[29.599571, 791.48663]]),
},
"forces": NOT_DEFINED,
},
}

def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None:
self.setup_dummy()
return super()._post_init(overwrite_local_cache, energy_unit, distance_unit)
@@ -90,3 +74,55 @@ def read_raw_entries(self):

def __len__(self):
return 9999


class PredefinedDataset(BaseDataset):
__name__ = "predefineddataset"
__energy_methods__ = [PotentialMethod.WB97M_D3BJ_DEF2_TZVPPD] # "wb97m-d3bj/def2-tzvppd"]
__force_mask__ = [True]
__energy_unit__ = "hartree"
__distance_unit__ = "bohr"
__forces_unit__ = "hartree/bohr"
force_target_names = __energy_methods__
energy_target_names = __energy_methods__

@property
def preprocess_path(self, overwrite_local_cache=False):
from os.path import join as p_join

from openqdc import get_project_root

return p_join(get_project_root(), "tests", "files", self.__name__, "preprocessed")

def is_preprocessed(self):
return True

def read_raw_entries(self):
pass

def read_preprocess(self, overwrite_local_cache=False):
logger.info("Reading preprocessed data.")
logger.info(
f"Dataset {self.__name__} with the following units:\n\
Energy: {self.energy_unit},\n\
Distance: {self.distance_unit},\n\
Forces: {self.force_unit if self.force_methods else 'None'}"
)
self.data = {}
for key in self.data_keys:
print(key, self.data_shapes[key], self.data_types[key])
filename = p_join(self.preprocess_path, f"{key}.mmap")
self.data[key] = np.memmap(filename, mode="r", dtype=self.data_types[key]).reshape(*self.data_shapes[key])

filename = p_join(self.preprocess_path, "props.pkl")
with open(filename, "rb") as f:
tmp = pkl.load(f)
for key in ["name", "subset", "n_atoms"]:
x = tmp.pop(key)
if len(x) == 2:
self.data[key] = x[0][x[1]]
else:
self.data[key] = x

for key in self.data:
logger.info(f"Loaded {key} with shape {self.data[key].shape}, dtype {self.data[key].dtype}")