Skip to content

Commit

Permalink
Fix tests, unit enums
Browse files Browse the repository at this point in the history
  • Loading branch information
FNTwin committed Jun 21, 2024
1 parent 1885a0b commit 2114551
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 37 deletions.
37 changes: 20 additions & 17 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from openqdc.utils.package_utils import has_package, requires_package
from openqdc.utils.regressor import Regressor # noqa
from openqdc.utils.units import get_conversion
from openqdc.utils.units import get_conversion, EnergyTypeConversion, DistanceTypeConversion, ForceTypeConversion

if has_package("torch"):
import torch
Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(
set_cache_dir(cache_dir)
# self._init_lambda_fn()
self.data = None
self._original_unit = self.__energy_unit__
self._original_unit = self.energy_unit
self.recompute_statistics = recompute_statistics
self.regressor_kwargs = regressor_kwargs
self.transform = transform
Expand Down Expand Up @@ -225,24 +225,24 @@ def e0s_dispatcher(self):
def _convert_data(self):
    """
    Convert every loaded array to the currently selected units.

    Logs the target energy, distance and force units ('None' is logged for
    forces when ``__force_methods__`` is empty), then passes each entry of
    ``self.data`` listed in ``self.data_keys`` through
    ``self._convert_on_loading`` and stores the result back under the same key.
    """
    forces = str(self.force_unit) if self.__force_methods__ else "None"
    # One line per unit; adjacent f-strings are concatenated at compile time,
    # so the message does not depend on source-line continuations.
    logger.info(
        f"Converting {self.__name__} data to the following units:\n"
        f"Energy: {str(self.energy_unit)},\n"
        f"Distance: {str(self.distance_unit)},\n"
        f"Forces: {forces}"
    )
    for key in self.data_keys:
        self.data[key] = self._convert_on_loading(self.data[key], key)

@property
def energy_unit(self):
    """
    Current energy unit of the dataset.

    Returns the ``EnergyTypeConversion`` member whose value matches
    ``__energy_unit__``, so callers can chain ``.to(...)`` on the result;
    use ``str(...)`` to get the plain unit name.
    """
    return EnergyTypeConversion(self.__energy_unit__)

@property
def distance_unit(self):
    """
    Current distance unit of the dataset.

    Returns the ``DistanceTypeConversion`` member whose value matches
    ``__distance_unit__``, so callers can chain ``.to(...)`` on the result;
    use ``str(...)`` to get the plain unit name.
    """
    return DistanceTypeConversion(self.__distance_unit__)

@property
def force_unit(self):
    """
    Current force unit of the dataset.

    Parses ``__forces_unit__`` (an ``"<energy>/<distance>"`` string) and
    returns the matching ``ForceTypeConversion`` member.
    """
    # The energy part may itself contain "/" (e.g. "kcal/mol/bohr"), so only
    # the last "/" separates the energy unit from the distance unit.
    energy, _, distance = self.__forces_unit__.rpartition("/")
    # Look the member up by its (energy, distance) value tuple. Passing the two
    # parts as separate positional arguments is only treated as a value lookup
    # from Python 3.12 on; earlier versions route it to Enum's functional API.
    return ForceTypeConversion((energy, distance))

@property
def root(self):
Expand Down Expand Up @@ -298,17 +298,18 @@ def _set_units(self, en, ds):
self.set_energy_unit(en)
self.set_distance_unit(ds)
if self.__force_methods__:
self.__forces_unit__ = self.energy_unit + "/" + self.distance_unit
self._fn_forces = get_conversion(old_en + "/" + old_ds, self.__forces_unit__)

#self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit)
self._fn_forces = self.force_unit.to(str(self.energy_unit),str(self.distance_unit)) #get_conversion(old_en + "/" + old_ds, self.__forces_unit__)
self.__forces_unit__ = str(self.energy_unit)+"/"+str(self.distance_unit)

def _set_isolated_atom_energies(self):
    """
    Compute the isolated atom energies expressed in the current energy unit.

    Picks a conversion callable depending on ``self.energy_type`` and applies
    it to ``self.e0s_dispatcher.e0s_matrix``, storing the result in
    ``__isolated_atom_energies__``. A missing ``__energy_methods__`` is only
    logged as an error; execution continues.
    """
    if self.__energy_methods__ is None:
        logger.error("No energy methods defined for this dataset.")
    if self.energy_type == "formation":
        f = get_conversion("hartree", self.__energy_unit__)
    else:
        # regression are calculated on the original unit of the dataset
        f = self._original_unit.to(self.energy_unit)
    self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix)

def convert_energy(self, x):
Expand All @@ -324,18 +325,20 @@ def set_energy_unit(self, value: str):
"""
Set a new energy unit for the dataset.
"""
old_unit = self.energy_unit
#old_unit = self.energy_unit
#self.__energy_unit__ = value
self._fn_energy = self.energy_unit.to(value) #get_conversion(old_unit, value)
self.__energy_unit__ = value
self._fn_energy = get_conversion(old_unit, value)

def set_distance_unit(self, value: str):
    """
    Set a new distance unit for the dataset.

    Stores in ``_fn_distance`` the callable converting from the current
    distance unit to ``value``, then records ``value`` as the current unit.
    """
    # Build the converter first: it must start from the unit in force *before*
    # ``__distance_unit__`` is overwritten.
    self._fn_distance = self.distance_unit.to(value)
    self.__distance_unit__ = value


def set_array_format(self, format: str):
assert format in ["numpy", "torch", "jax"], f"Format {format} not supported."
self.array_format = format
Expand Down
16 changes: 15 additions & 1 deletion openqdc/datasets/potential/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,21 @@ class ANI1CCX_V2(ANI1CCX):


class ANI2X(ANI1):
""" """
"""
The ANI-2X dataset was constructed using active learning from modified versions of GDB-11, ChEMBL,
and s66x8. It adds three new elements (F, Cl, S) resulting in 4.6 million conformers from 13k
chemical isomers, optimized using the LBFGS algorithm and labeled with ωB97X/6-31G*.
Usage
```python
from openqdc.datasets import ANI2X
dataset = ANI2X()
```
References:
- ANI-2x: https://doi.org/10.1021/acs.jctc.0c00121
- Github: https://github.com/aiqm/ANI1x_datasets
"""

__name__ = "ani2x"
__energy_unit__ = "hartree"
Expand Down
2 changes: 1 addition & 1 deletion openqdc/methods/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class StrEnum(str, Enum):
    """
    String-valued enum base class.

    Members are ``str`` instances (they compare equal to their raw value),
    while ``str(member)`` yields the lowercased value.
    """

    def __str__(self):
        return self.value.lower()


@unique
Expand Down
122 changes: 104 additions & 18 deletions openqdc/utils/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,77 @@
"""

from typing import Callable
from enum import Enum, unique

from openqdc.utils.exceptions import ConversionAlreadyDefined, ConversionNotDefinedError

class StrEnum(str, Enum):
    """String-valued enum base whose ``str()`` form is the lowercased member value."""

    def __str__(self):
        lowered = self.value.lower()
        return lowered

class ConversionEnum(Enum):
    """Enum base class adding a helper that enumerates the member values."""

    @classmethod
    def list(cls):
        """Return the values of all members as a list, in definition order."""
        return [member.value for member in cls]

@unique
class EnergyTypeConversion(ConversionEnum, StrEnum):
    """Energy units that can be converted into one another."""

    KCAL_MOL = "kcal/mol"
    KJ_MOL = "kj/mol"
    HARTREE = "hartree"
    EV = "ev"
    MEV = "mev"
    RYD = "ryd"

    def to(self, energy: "EnergyTypeConversion"):
        """
        Return the callable converting values from this unit to ``energy``.

        Both units are passed to ``get_conversion`` in their ``str()`` form,
        so ``energy`` may be a member or a plain unit string.
        """
        source_unit = str(self)
        target_unit = str(energy)
        return get_conversion(source_unit, target_unit)

@unique
class DistanceTypeConversion(ConversionEnum, StrEnum):
    """Distance units that can be converted into one another."""

    ANG = "ang"
    NM = "nm"
    BOHR = "bohr"

    def to(self, distance: "DistanceTypeConversion", fraction: bool = False):
        """
        Return the callable converting values from this unit to ``distance``.

        With ``fraction=True`` the lookup is reversed (``distance`` -> this
        unit), which is the factor needed when the distance sits in the
        denominator of the quantity being converted.
        """
        if fraction:
            return get_conversion(str(distance), str(self))
        return get_conversion(str(self), str(distance))

@unique
class ForceTypeConversion(ConversionEnum):
    """
    Force units, each defined as an (energy unit, distance unit) pair.

    Every member's value is the tuple of its ``EnergyTypeConversion`` and
    ``DistanceTypeConversion`` components, also exposed as ``.energy`` and
    ``.distance``.
    """

    # Name = EnergyTypeConversion, DistanceTypeConversion
    HARTREE_BOHR = (EnergyTypeConversion.HARTREE, DistanceTypeConversion.BOHR)
    HARTREE_ANG = (EnergyTypeConversion.HARTREE, DistanceTypeConversion.ANG)
    HARTREE_NM = (EnergyTypeConversion.HARTREE, DistanceTypeConversion.NM)
    EV_BOHR = (EnergyTypeConversion.EV, DistanceTypeConversion.BOHR)
    EV_ANG = (EnergyTypeConversion.EV, DistanceTypeConversion.ANG)
    EV_NM = (EnergyTypeConversion.EV, DistanceTypeConversion.NM)
    KCAL_MOL_BOHR = (EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.BOHR)
    KCAL_MOL_ANG = (EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.ANG)
    KCAL_MOL_NM = (EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.NM)
    KJ_MOL_BOHR = (EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.BOHR)
    KJ_MOL_ANG = (EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.ANG)
    KJ_MOL_NM = (EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.NM)
    MEV_BOHR = (EnergyTypeConversion.MEV, DistanceTypeConversion.BOHR)
    MEV_ANG = (EnergyTypeConversion.MEV, DistanceTypeConversion.ANG)
    MEV_NM = (EnergyTypeConversion.MEV, DistanceTypeConversion.NM)
    RYD_BOHR = (EnergyTypeConversion.RYD, DistanceTypeConversion.BOHR)
    RYD_ANG = (EnergyTypeConversion.RYD, DistanceTypeConversion.ANG)
    RYD_NM = (EnergyTypeConversion.RYD, DistanceTypeConversion.NM)

    def __init__(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion):
        # Enum unpacks the member's tuple value into these two arguments.
        self.energy = energy
        self.distance = distance

    def __str__(self):
        return f"{self.energy}/{self.distance}"

    def to(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion):
        """
        Return a callable converting force values from this unit to
        ``energy``/``distance``.

        The energy part is converted first; the distance part is then applied
        with ``fraction=True`` because the distance is in the denominator.
        The underlying conversions are looked up each time the returned
        callable is invoked, not when ``to`` is called.
        """

        def convert(x):
            per_distance = self.distance.to(distance, fraction=True)
            in_energy = self.energy.to(energy)
            return per_distance(in_energy(x))

        return convert


CONVERSION_REGISTRY = {}


Expand All @@ -27,7 +95,11 @@ class Conversion:
The callable to compute the conversion
"""

def __init__(self, in_unit: str, out_unit: str, func: Callable[[float], float]):
def __init__(self,
in_unit: str,
out_unit: str,
func: Callable[[float], float]
):
"""
Parameters
Expand Down Expand Up @@ -68,21 +140,35 @@ def get_conversion(in_unit: str, out_unit: str):
Conversion("ev", "kj/mol", lambda x: x * 96.4853)
# 1 eV = 1000 meV, so meV -> eV divides by 1000 and eV -> meV multiplies by 1000.
# The "<unit> -> mev" conversions below are chained through "ev" -> "mev" and
# therefore depend on this factor being right.
Conversion("mev", "ev", lambda x: x * 0.001)
Conversion("ev", "mev", lambda x: x * 1000.0)
Conversion("ev", "ryd", lambda x: x * 0.07349864)

# kcal/mol conversion
Conversion("kcal/mol", "ev", lambda x: x * 0.0433641)
Conversion("kcal/mol", "hartree", lambda x: x * 0.00159362)
Conversion("kcal/mol", "kj/mol", lambda x: x * 4.184)
Conversion("kcal/mol", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("kcal/mol", "ev")(x)))
Conversion("kcal/mol", "ryd", lambda x: x * 0.00318720)

# hartree conversion
Conversion("hartree", "ev", lambda x: x * 27.211386246)
Conversion("hartree", "kcal/mol", lambda x: x * 627.509)
Conversion("hartree", "kj/mol", lambda x: x * 2625.5)
Conversion("hartree", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("hartree", "ev")(x)))
Conversion("hartree", "ryd", lambda x: x * 2.0)

# kj/mol conversion
Conversion("kj/mol", "ev", lambda x: x * 0.0103643)
Conversion("kj/mol", "kcal/mol", lambda x: x * 0.239006)
Conversion("kj/mol", "hartree", lambda x: x * 0.000380879)
Conversion("kj/mol", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("kj/mol", "ev")(x)))
# 1 hartree = 2 ryd, so this is twice the kj/mol -> hartree factor
# (and the reciprocal of ryd -> kj/mol = 1312.75).
Conversion("kj/mol", "ryd", lambda x: x * 0.000761758)

# Rydberg conversion
Conversion("ryd", "ev", lambda x: x * 13.60569301)
Conversion("ryd", "kcal/mol", lambda x: x * 313.7545)
Conversion("ryd", "hartree", lambda x: x * 0.5)
Conversion("ryd", "kj/mol", lambda x: x * 1312.75)
Conversion("ryd", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("ryd", "ev")(x)))

# distance conversions
Conversion("bohr", "ang", lambda x: x * 0.52917721092)
Expand All @@ -92,20 +178,20 @@ def get_conversion(in_unit: str, out_unit: str):
Conversion("nm", "bohr", lambda x: x * 18.8973)
Conversion("bohr", "nm", lambda x: x / 18.8973)

# common forces conversion
# Each force conversion is composed from the energy and distance conversions
# registered above. A force is energy per distance, so changing the distance
# unit in the denominator applies the *reverse* distance conversion
# (e.g. ".../bohr" -> ".../ang" uses the "ang" -> "bohr" factor).
Conversion("hartree/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "ev")(x)))
Conversion("hartree/bohr", "ev/bohr", lambda x: get_conversion("hartree", "ev")(x))
Conversion("hartree/bohr", "kcal/mol/bohr", lambda x: get_conversion("hartree", "kcal/mol")(x))
Conversion(
"hartree/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "kcal/mol")(x))
)
Conversion("hartree/ang", "kcal/mol/ang", lambda x: get_conversion("hartree", "kcal/mol")(x))
Conversion("hartree/ang", "hartree/bohr", lambda x: get_conversion("bohr", "ang")(x))
Conversion("hartree/bohr", "hartree/ang", lambda x: get_conversion("ang", "bohr")(x))
Conversion("kcal/mol/bohr", "hartree/bohr", lambda x: get_conversion("kcal/mol", "hartree")(x))
Conversion("ev/ang", "hartree/ang", lambda x: get_conversion("ev", "hartree")(x))
Conversion("ev/bohr", "hartree/bohr", lambda x: get_conversion("ev", "hartree")(x))
Conversion("ev/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(x))
Conversion("ev/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("ev", "kcal/mol")(x)))
Conversion("kcal/mol/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(x))
Conversion("ev/ang", "kcal/mol/ang", lambda x: get_conversion("ev", "kcal/mol")(x))
## common forces conversion
#Conversion("hartree/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "ev")(x)))
#Conversion("hartree/bohr", "ev/bohr", lambda x: get_conversion("hartree", "ev")(x))
#Conversion("hartree/bohr", "kcal/mol/bohr", lambda x: get_conversion("hartree", "kcal/mol")(x))
#Conversion(
# "hartree/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "kcal/mol")(x))
#)
#Conversion("hartree/ang", "kcal/mol/ang", lambda x: get_conversion("hartree", "kcal/mol")(x))
#Conversion("hartree/ang", "hartree/bohr", lambda x: get_conversion("bohr", "ang")(x))
#Conversion("hartree/bohr", "hartree/ang", lambda x: get_conversion("ang", "bohr")(x))
#Conversion("kcal/mol/bohr", "hartree/bohr", lambda x: get_conversion("kcal/mol", "hartree")(x))
#Conversion("ev/ang", "hartree/ang", lambda x: get_conversion("ev", "hartree")(x))
#Conversion("ev/bohr", "hartree/bohr", lambda x: get_conversion("ev", "hartree")(x))
#Conversion("ev/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(x))
#Conversion("ev/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("ev", "kcal/mol")(x)))
#Conversion("kcal/mol/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(x))
#Conversion("ev/ang", "kcal/mol/ang", lambda x: get_conversion("ev", "kcal/mol")(x))
9 changes: 9 additions & 0 deletions tests/test_filedataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from io import StringIO
import os

import numpy as np
import pytest

from openqdc.datasets.io import XYZDataset
from openqdc.methods.enums import PotentialMethod
from openqdc.utils.package_utils import has_package
from openqdc.utils.io import get_local_cache

if has_package("torch"):
import torch
Expand All @@ -20,6 +22,13 @@
}


@pytest.fixture(autouse=True)
def clean_before_run():
    """Remove the cached ``XYZDataset`` directory before each test, then run the test."""
    import shutil

    # start by removing any cached data
    cache_dir = get_local_cache()
    # rmtree avoids building a shell command from the path (no quoting issues,
    # no dependency on an `rm` binary); ignore_errors keeps it silent when the
    # directory does not exist, as `rm -rf` was.
    shutil.rmtree(os.path.join(cache_dir, "XYZDataset"), ignore_errors=True)
    yield

@pytest.fixture
def xyz_filelike():
xyz_str = """3
Expand Down

0 comments on commit 2114551

Please sign in to comment.