Skip to content

Commit

Permalink
Merge pull request #85 from OpenDrugDiscovery/atom_ener_structure
Browse files Browse the repository at this point in the history
Improved atom energy code + fixes
  • Loading branch information
FNTwin authored May 2, 2024
2 parents 051c084 + 2f4b692 commit dcd1bea
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 60 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,5 @@ jobs:
- name: Run tests
run: python -m pytest

- name: Test building the doc
run: mkdocs build
#- name: Test building the doc
# run: mkdocs build
4 changes: 2 additions & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ docs_dir: "docs"
nav:
- Overview: index.md
- Available Datasets: datasets.md
- Tutorials:
- Really hard example: tutorials/usage.ipynb
#- Tutorials:
# #- Really hard example: tutorials/usage.ipynb
- API:
- Datasets: API/available_datasets.md
- Isolated Atoms Energies: API/isolated_atom_energies.md
Expand Down
53 changes: 10 additions & 43 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __init__(
set_cache_dir(cache_dir)
# self._init_lambda_fn()
self.data = None
self._original_unit = self.__energy_unit__
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
Expand Down Expand Up @@ -268,6 +269,10 @@ def pkl_data_keys(self):
def pkl_data_types(self):
return {"name": str, "subset": str, "n_atoms": np.int32}

@property
def atom_energies(self):
return self._e0s_dispatcher

@property
def data_types(self):
return {
Expand Down Expand Up @@ -299,7 +304,11 @@ def _set_units(self, en, ds):
def _set_isolated_atom_energies(self):
    """
    Convert the isolated atom energies matrix to the current energy unit.

    The matrix is read from ``self.e0s_dispatcher.e0s_matrix`` and the
    converted result is stored in ``self.__isolated_atom_energies__``.
    An error is logged (but nothing is raised here) when the dataset
    defines no energy methods.
    """
    if self.__energy_methods__ is None:
        logger.error("No energy methods defined for this dataset.")
    # "formation" energies are converted from hartree; any other energy type
    # (regression) is calculated on the original unit of the dataset, so the
    # conversion must start from that unit instead.
    if self.energy_type == "formation":
        source_unit = "hartree"
    else:
        source_unit = self._original_unit
    f = get_conversion(source_unit, self.__energy_unit__)
    self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

def convert_energy(self, x):
Expand Down Expand Up @@ -558,48 +567,6 @@ def wrapper(idx):
datum["idxs"] = idxs
return datum

@classmethod
def as_dataloader(
    cls,
    batch_size: int = 8,
    energy_unit: Optional[str] = None,
    distance_unit: Optional[str] = None,
    array_format: str = "torch",
    energy_type: str = "formation",
    overwrite_local_cache: bool = False,
    cache_dir: Optional[str] = None,
    recompute_statistics: bool = False,
    transform: Optional[Callable] = None,
):
    """
    Return the dataset as a dataloader.

    Parameters
    ----------
    batch_size : int, optional
        Batch size, by default 8
    transform : Callable, optional
        Callable applied to every sample. When omitted, each sample is
        wrapped in a ``torch_geometric.data.Data`` object.

    For other parameters, see the __init__ method.

    Raises
    ------
    ImportError
        If ``torch_geometric`` is not installed.
    """
    if not has_package("torch_geometric"):
        raise ImportError("torch_geometric is required to use this method.")
    assert array_format in ["torch", "jax"], f"Format {array_format} must be torch or jax."
    # Imported lazily: torch_geometric is an optional dependency.
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader

    def _to_sample(x):
        # The user transform has to be *called* on the sample; returning the
        # callable itself would hand the function object to the loader.
        if transform is None:
            return Data(**x)
        return transform(x)

    dataset = cls(
        energy_unit=energy_unit,
        distance_unit=distance_unit,
        array_format=array_format,
        energy_type=energy_type,
        overwrite_local_cache=overwrite_local_cache,
        cache_dir=cache_dir,
        recompute_statistics=recompute_statistics,
        transform=_to_sample,
    )
    return DataLoader(dataset, batch_size=batch_size)

def as_iter(self, atoms: bool = False, energy_method: int = 0):
"""
Return the dataset as an iterator.
Expand Down
150 changes: 138 additions & 12 deletions openqdc/datasets/energies.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from os.path import join as p_join
from typing import Dict, Union

import numpy as np
from loguru import logger

from openqdc.methods.enums import PotentialMethod
from openqdc.utils.constants import ATOM_SYMBOLS, ATOMIC_NUMBERS
from openqdc.utils.io import load_pkl, save_pkl
from openqdc.utils.regressor import Regressor

Expand Down Expand Up @@ -41,6 +44,56 @@ def dispatch_factory(data, **kwargs) -> "IsolatedEnergyInterface":
return NullEnergy(data, **kwargs)


@dataclass(frozen=False, eq=True)
class AtomSpecies:
    """
    Structure that defines a pair of chemical species and charge.

    The species may be given as a chemical symbol or as an atomic number;
    a number is converted to its chemical symbol through ``ATOM_SYMBOLS``
    and the atomic number is stored in ``number``.

    Instances hash like the plain tuple ``(symbol, charge)`` and compare
    equal to any ``(symbol or atomic number, charge)`` sequence describing
    the same atom and charge.
    """

    symbol: Union[str, int]
    charge: int = 0

    def __post_init__(self):
        # Normalise to the symbol form first, then look up the atomic number.
        if not isinstance(self.symbol, str):
            self.symbol = ATOM_SYMBOLS[self.symbol]
        self.number = ATOMIC_NUMBERS[self.symbol]

    def __hash__(self):
        return hash((self.symbol, self.charge))

    def __eq__(self, other):
        if not isinstance(other, AtomSpecies):
            # Accept (symbol_or_number, charge) sequences. Anything that
            # cannot be interpreted that way is "not comparable" rather than
            # an error, so ``species == None`` is simply False.
            try:
                other = AtomSpecies(symbol=other[0], charge=other[1])
            except (TypeError, IndexError, KeyError):
                return NotImplemented
        return (self.number, self.charge) == (other.number, other.charge)


@dataclass
class AtomEnergy:
    """
    Data structure to store isolated atom energies
    and the standard deviation associated with each value.

    By default the std is 1 when no value was calculated
    or none is available (formation energy case).
    Values that are not already numpy arrays are wrapped into
    one-element float32 arrays, so ``mean`` and ``std`` are always arrays.
    """

    mean: np.ndarray
    std: np.ndarray = field(default_factory=lambda: np.array([1], dtype=np.float32))

    def __post_init__(self):
        # Normalise both fields the same way so that a scalar std (e.g. a
        # plain ``1``) does not stay a bare Python number next to an array mean.
        if not isinstance(self.mean, np.ndarray):
            self.mean = np.array([self.mean], dtype=np.float32)
        if not isinstance(self.std, np.ndarray):
            self.std = np.array([self.std], dtype=np.float32)

    def append(self, other: "AtomEnergy"):
        """
        Append the mean and std of another atom energy
        """
        # np.append returns new flattened arrays; rebind instead of mutating.
        self.mean = np.append(self.mean, other.mean)
        self.std = np.append(self.std, other.std)


class AtomEnergies:
"""
Manager class for interface with the isolated atom energies classes
Expand Down Expand Up @@ -71,15 +124,49 @@ def e0s_matrix(self) -> np.ndarray:
"""
return self.factory.e0_matrix

@property
def e0s_dict(self) -> Dict[AtomSpecies, AtomEnergy]:
"""
Return the isolated atom energies dictionary
"""
return self.factory.e0_dict

def __str__(self):
return f"Atoms: { list(set(map(lambda x : x.symbol, self.e0s_dict.keys())))}"

def __repr__(self):
return str(self)

def __getitem__(self, item: AtomSpecies) -> AtomEnergy:
    """
    Retrieve an entry from the isolated atom energies dictionary.

    Item can be written as tuple(symbol, charge) or
    tuple(atomic number, charge). A bare symbol, a bare atomic number
    or an ``AtomSpecies`` instance is accepted as well. If no charge is
    passed, it is automatically set to 0.

    Examples:
        AtomEnergies[6], AtomEnergies[6,1],
        AtomEnergies["C",1], AtomEnergies[(6,1)]
        AtomEnergies[("C",1)], AtomEnergies["Cl"]
    """
    if isinstance(item, AtomSpecies):
        return self.e0s_dict[item]
    if isinstance(item, (str, int, np.integer)):
        # A bare symbol or atomic number. Strings must not be indexed:
        # "Cl"[0], "Cl"[1] would be read as atom "C" with charge "l".
        atom, charge = item, 0
    else:
        try:
            atom, charge = item[0], item[1]
        except TypeError:
            # Not subscriptable: treat the object itself as the atom.
            atom, charge = item, 0
        except IndexError:
            # One-element sequence: only the atom was given.
            atom, charge = item[0], 0
    if not isinstance(atom, str):
        atom = ATOM_SYMBOLS[atom]
    return self.e0s_dict[(atom, charge)]


class IsolatedEnergyInterface(ABC):
"""
Abstract class that defines the interface for the
different implementation of an isolated atom energy value
"""

_e0_matrixs = []

def __init__(self, data, **kwargs):
"""
Parameters
Expand All @@ -93,7 +180,8 @@ def __init__(self, data, **kwargs):
selected energy class. Mostly used for regression
to pass the regressor_kwargs.
"""

self._e0_matrixs = []
self._e0_dict = None
self.kwargs = kwargs
self.data = data
self._post_init()
Expand All @@ -120,27 +208,61 @@ def e0_matrix(self) -> np.ndarray:
"""
return np.array(self._e0_matrixs)

@property
def e0_dict(self) -> Dict:
"""
Return the isolated atom energies dict
"""

return self._e0s_dict

def __str__(self) -> str:
return self.__class__.__name__.lower()


class PhysicalEnergy(IsolatedEnergyInterface):
    """
    Class that returns a physical (SE,DFT,etc) isolated atom energies.
    """

    def _assembly_e0_dict(self):
        """
        Build ``self._e0s_dict`` from the ``atom_energies_dict`` of every
        energy method of the dataset. Entries for the same species/charge
        coming from successive methods are appended to the first entry.
        """
        datum = {}
        for method in self.data.__energy_methods__:
            for key, values in method.atom_energies_dict.items():
                atm = AtomSpecies(*key)
                ens = AtomEnergy(values)
                if atm in datum:
                    datum[atm].append(ens)
                else:
                    datum[atm] = ens
        self._e0s_dict = datum

    def _post_init(self):
        # One isolated atom energies matrix per energy method of the dataset.
        self._e0_matrixs = [energy_method.atom_energies_matrix for energy_method in self.data.__energy_methods__]
        self._assembly_e0_dict()


class NullEnergy(IsolatedEnergyInterface):
    """
    Class that returns a null (zeros) matrix for the isolated atom energies in case
    of no energies are available.
    """

    def _assembly_e0_dict(self):
        """
        Build ``self._e0s_dict`` from the ``PotentialMethod.NONE`` atom
        energies dictionary, repeated once per energy method of the dataset
        so that every entry holds one value per method.
        """
        datum = {}
        for _ in self.data.__energy_methods__:
            for key, values in PotentialMethod.NONE.atom_energies_dict.items():
                atm = AtomSpecies(*key)
                ens = AtomEnergy(values)
                if atm in datum:
                    datum[atm].append(ens)
                else:
                    datum[atm] = ens
        self._e0s_dict = datum

    def _post_init(self):
        # The PotentialMethod.NONE matrix stands in for every energy method.
        self._e0_matrixs = [PotentialMethod.NONE.atom_energies_matrix for _ in range(len(self.data.energy_methods))]
        self._assembly_e0_dict()


class RegressionEnergy(IsolatedEnergyInterface):
Expand Down Expand Up @@ -175,7 +297,9 @@ def _set_lin_atom_species_dict(self, E0s, covs) -> None:
"""
atomic_energies_dict = {}
for i, z in enumerate(self.regressor.numbers):
atomic_energies_dict[z] = E0s[i]
for charge in range(-10, 11):
atomic_energies_dict[AtomSpecies(z, charge)] = AtomEnergy(E0s[i], 1 if covs is None else covs[i])
# atomic_energies_dict[z] = E0s[i]
self._e0s_dict = atomic_energies_dict
self.save_e0s()

Expand All @@ -187,7 +311,9 @@ def _set_linear_e0s(self) -> None:
new_e0s = [np.zeros((max(self.data.numbers) + 1, MAX_CHARGE_NUMBER)) for _ in range(len(self))]
for z, e0 in self._e0s_dict.items():
for i in range(len(self)):
new_e0s[i][z, :] = e0[i]
# new_e0s[i][z, :] = e0[i]
new_e0s[i][z.number, z.charge] = e0.mean[i]
# for atom_sp, values in
self._e0_matrixs = new_e0s

def save_e0s(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions openqdc/datasets/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
self.recompute_statistics = True
self.refit_e0s = True
self.energy_type = energy_type
self._original_unit = energy_unit
self.__energy_unit__ = energy_unit
self.__distance_unit__ = distance_unit
self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory]
Expand Down
11 changes: 10 additions & 1 deletion openqdc/methods/enums.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from enum import Enum

from loguru import logger
from numpy import array, float32

from openqdc.methods.atom_energies import atom_energy_collection, to_e_matrix
from openqdc.utils.constants import ATOM_SYMBOLS


class StrEnum(str, Enum):
Expand Down Expand Up @@ -472,6 +474,13 @@ class PotentialMethod(QmMethod): # SPLIT FOR INTERACTIO ENERGIES AND FIX MD1
XLYP_TZP = Functional.XLYP, BasisSet.TZP
NONE = Functional.NONE, BasisSet.NONE

def _build_default_dict(self):
    """
    Build the fallback atom energies dictionary.

    Every symbol of ``ATOM_SYMBOLS`` is combined with each integer charge
    from -10 to 10 (inclusive); each ``(symbol, charge)`` key maps to its
    own one-element float32 array holding 0.
    """
    charges = range(-10, 11)
    return {
        (symbol, charge): array([0], dtype=float32)
        for symbol in ATOM_SYMBOLS
        for charge in charges
    }

@property
def atom_energies_dict(self):
"""Get the atomization energy dictionary"""
Expand All @@ -483,7 +492,7 @@ def atom_energies_dict(self):
raise
except: # noqa
logger.info(f"No available atomization energy for the QM method {key}. All values are set to 0.")

energies = self._build_default_dict()
return energies


Expand Down
2 changes: 2 additions & 0 deletions openqdc/utils/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def solve(self):
else:
X, y = self.X, self.y[:, energy_idx]
E0s, cov = self.solver(X, y)
if cov is None:
cov = np.zeros_like(E0s) + 1.0
E0_list.append(E0s)
cov_list.append(cov)
return np.vstack(E0_list).T, np.vstack(cov_list).T
Expand Down
Loading

0 comments on commit dcd1bea

Please sign in to comment.