diff --git a/openqdc/datasets/base.py b/openqdc/datasets/base.py index db04fcd..0c94ec5 100644 --- a/openqdc/datasets/base.py +++ b/openqdc/datasets/base.py @@ -38,7 +38,7 @@ ) from openqdc.utils.package_utils import has_package, requires_package from openqdc.utils.regressor import Regressor # noqa -from openqdc.utils.units import get_conversion +from openqdc.utils.units import get_conversion, EnergyTypeConversion, DistanceTypeConversion, ForceTypeConversion if has_package("torch"): import torch @@ -129,7 +129,7 @@ def __init__( set_cache_dir(cache_dir) # self._init_lambda_fn() self.data = None - self._original_unit = self.__energy_unit__ + self._original_unit = self.energy_unit self.recompute_statistics = recompute_statistics self.regressor_kwargs = regressor_kwargs self.transform = transform @@ -225,24 +225,24 @@ def e0s_dispatcher(self): def _convert_data(self): logger.info( f"Converting {self.__name__} data to the following units:\n\ - Energy: {self.energy_unit},\n\ - Distance: {self.distance_unit},\n\ - Forces: {self.force_unit if self.__force_methods__ else 'None'}" + Energy: {str(self.energy_unit)},\n\ + Distance: {str(self.distance_unit)},\n\ + Forces: {str(self.force_unit) if self.__force_methods__ else 'None'}" ) for key in self.data_keys: self.data[key] = self._convert_on_loading(self.data[key], key) @property def energy_unit(self): - return self.__energy_unit__ + return EnergyTypeConversion(self.__energy_unit__) @property def distance_unit(self): - return self.__distance_unit__ + return DistanceTypeConversion(self.__distance_unit__) @property def force_unit(self): - return self.__forces_unit__ + return ForceTypeConversion(*self.__forces_unit__.split("/")) @property def root(self): @@ -298,9 +298,10 @@ def _set_units(self, en, ds): self.set_energy_unit(en) self.set_distance_unit(ds) if self.__force_methods__: - self.__forces_unit__ = self.energy_unit + "/" + self.distance_unit - self._fn_forces = get_conversion(old_en + "/" + old_ds, 
self.__forces_unit__) - + #self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit) + self._fn_forces = self.force_unit.to(str(self.energy_unit),str(self.distance_unit)) #get_conversion(old_en + "/" + old_ds, self.__forces_unit__) + self.__forces_unit__ = str(self.energy_unit)+"/"+str(self.distance_unit) + def _set_isolated_atom_energies(self): if self.__energy_methods__ is None: logger.error("No energy methods defined for this dataset.") @@ -308,7 +309,7 @@ def _set_isolated_atom_energies(self): f = get_conversion("hartree", self.__energy_unit__) else: # regression are calculated on the original unit of the dataset - f = get_conversion(self._original_unit, self.__energy_unit__) + f = self._original_unit.to(self.energy_unit) self.__isolated_atom_energies__ = f(self.e0s_dispatcher.e0s_matrix) def convert_energy(self, x): @@ -324,18 +325,20 @@ def set_energy_unit(self, value: str): """ Set a new energy unit for the dataset. """ - old_unit = self.energy_unit + #old_unit = self.energy_unit + #self.__energy_unit__ = value + self._fn_energy = self.energy_unit.to(value) #get_conversion(old_unit, value) self.__energy_unit__ = value - self._fn_energy = get_conversion(old_unit, value) def set_distance_unit(self, value: str): """ Set a new distance unit for the dataset. """ - old_unit = self.distance_unit + #old_unit = self.distance_unit + # self.__distance_unit__ = value + self._fn_distance = self.distance_unit.to(value) #get_conversion(old_unit, value) self.__distance_unit__ = value - self._fn_distance = get_conversion(old_unit, value) - + def set_array_format(self, format: str): assert format in ["numpy", "torch", "jax"], f"Format {format} not supported." 
self.array_format = format diff --git a/openqdc/datasets/potential/ani.py b/openqdc/datasets/potential/ani.py index b31d658..bcff384 100644 --- a/openqdc/datasets/potential/ani.py +++ b/openqdc/datasets/potential/ani.py @@ -210,7 +210,21 @@ class ANI1CCX_V2(ANI1CCX): class ANI2X(ANI1): - """ """ + """ + The ANI-2X dataset was constructed using active learning from modified versions of GDB-11, ChEMBL, + and s66x8. It adds three new elements (F, Cl, S) resulting in 4.6 million conformers from 13k + chemical isomers, optimized using the LBFGS algorithm and labeled with ωB97X/6-31G*. + + Usage + ```python + from openqdc.datasets import ANI2X + dataset = ANI2X() + ``` + + References: + - ANI-2x: https://doi.org/10.1021/acs.jctc.0c00121 + - Github: https://github.com/aiqm/ANI1x_datasets + """ __name__ = "ani2x" __energy_unit__ = "hartree" diff --git a/openqdc/methods/enums.py b/openqdc/methods/enums.py index 6689f3a..8270d13 100644 --- a/openqdc/methods/enums.py +++ b/openqdc/methods/enums.py @@ -9,7 +9,7 @@ class StrEnum(str, Enum): def __str__(self): - return self.value + return self.value.lower() @unique diff --git a/openqdc/utils/units.py b/openqdc/utils/units.py index 12c13f9..2b3529f 100644 --- a/openqdc/utils/units.py +++ b/openqdc/utils/units.py @@ -9,9 +9,77 @@ """ from typing import Callable +from enum import Enum, unique from openqdc.utils.exceptions import ConversionAlreadyDefined, ConversionNotDefinedError +class StrEnum(str, Enum): + def __str__(self): + return self.value.lower() + +class ConversionEnum(Enum): + + @classmethod + def list(cls): + return list(map(lambda c: c.value, cls)) + +@unique +class EnergyTypeConversion(ConversionEnum, StrEnum): + KCAL_MOL = "kcal/mol" + KJ_MOL = "kj/mol" + HARTREE = "hartree" + EV = "ev" + MEV = "mev" + RYD = "ryd" + + + def to(self, energy: "EnergyTypeConversion"): + return get_conversion(str(self), str(energy)) + +@unique +class DistanceTypeConversion(ConversionEnum, StrEnum): + ANG = "ang" + NM = "nm" + BOHR = 
"bohr" + + def to(self, distance: "DistanceTypeConversion", fraction : bool = False): + return get_conversion(str(self), str(distance)) if not fraction else get_conversion(str(distance), str(self)) + +@unique +class ForceTypeConversion(ConversionEnum): + # Name = EnergyTypeConversion, , DistanceTypeConversion + HARTREE_BOHR = EnergyTypeConversion.HARTREE , DistanceTypeConversion.BOHR + HARTREE_ANG = EnergyTypeConversion.HARTREE , DistanceTypeConversion.ANG + HARTREE_NM = EnergyTypeConversion.HARTREE , DistanceTypeConversion.NM + EV_BOHR = EnergyTypeConversion.EV , DistanceTypeConversion.BOHR + EV_ANG = EnergyTypeConversion.EV , DistanceTypeConversion.ANG + EV_NM = EnergyTypeConversion.EV , DistanceTypeConversion.NM + KCAL_MOL_BOHR = EnergyTypeConversion.KCAL_MOL , DistanceTypeConversion.BOHR + KCAL_MOL_ANG = EnergyTypeConversion.KCAL_MOL , DistanceTypeConversion.ANG + KCAL_MOL_NM = EnergyTypeConversion.KCAL_MOL , DistanceTypeConversion.NM + KJ_MOL_BOHR = EnergyTypeConversion.KJ_MOL , DistanceTypeConversion.BOHR + KJ_MOL_ANG = EnergyTypeConversion.KJ_MOL , DistanceTypeConversion.ANG + KJ_MOL_NM = EnergyTypeConversion.KJ_MOL , DistanceTypeConversion.NM + MEV_BOHR = EnergyTypeConversion.MEV , DistanceTypeConversion.BOHR + MEV_ANG = EnergyTypeConversion.MEV , DistanceTypeConversion.ANG + MEV_NM = EnergyTypeConversion.MEV , DistanceTypeConversion.NM + RYD_BOHR = EnergyTypeConversion.RYD , DistanceTypeConversion.BOHR + RYD_ANG = EnergyTypeConversion.RYD , DistanceTypeConversion.ANG + RYD_NM = EnergyTypeConversion.RYD , DistanceTypeConversion.NM + + def __init__(self, + energy: EnergyTypeConversion, + distance: DistanceTypeConversion): + self.energy = energy + self.distance = distance + + def __str__(self): + return f"{self.energy}/{self.distance}" + + def to(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion): + return lambda x : self.distance.to(distance, fraction=True)(self.energy.to(energy)(x)) + + CONVERSION_REGISTRY = {} @@ -27,7 +95,11 @@ class 
Conversion: The callable to compute the conversion """ - def __init__(self, in_unit: str, out_unit: str, func: Callable[[float], float]): + def __init__(self, + in_unit: str, + out_unit: str, + func: Callable[[float], float] + ): """ Parameters @@ -68,21 +140,35 @@ def get_conversion(in_unit: str, out_unit: str): Conversion("ev", "kj/mol", lambda x: x * 96.4853) Conversion("mev", "ev", lambda x: x * 1000.0) Conversion("ev", "mev", lambda x: x * 0.0001) +Conversion("ev", "ryd", lambda x: x * 0.07349864) # kcal/mol conversion Conversion("kcal/mol", "ev", lambda x: x * 0.0433641) Conversion("kcal/mol", "hartree", lambda x: x * 0.00159362) Conversion("kcal/mol", "kj/mol", lambda x: x * 4.184) +Conversion("kcal/mol", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("kcal/mol", "ev")(x))) +Conversion("kcal/mol", "ryd", lambda x: x * 0.00318720) # hartree conversion Conversion("hartree", "ev", lambda x: x * 27.211386246) Conversion("hartree", "kcal/mol", lambda x: x * 627.509) Conversion("hartree", "kj/mol", lambda x: x * 2625.5) +Conversion("hartree", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("hartree", "ev")(x))) +Conversion("hartree", "ryd", lambda x: x * 2.0) # kj/mol conversion Conversion("kj/mol", "ev", lambda x: x * 0.0103643) Conversion("kj/mol", "kcal/mol", lambda x: x * 0.239006) Conversion("kj/mol", "hartree", lambda x: x * 0.000380879) +Conversion("kj/mol", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("kj/mol", "ev")(x))) +Conversion("kj/mol", "ryd", lambda x: x * 0.000301318) + +# Rydberg conversion +Conversion("ryd", "ev", lambda x: x * 13.60569301) +Conversion("ryd", "kcal/mol", lambda x: x * 313.7545) +Conversion("ryd", "hartree", lambda x: x * 0.5) +Conversion("ryd", "kj/mol", lambda x: x * 1312.75) +Conversion("ryd", "mev", lambda x: get_conversion("ev", "mev")(get_conversion("ryd", "ev")(x))) # distance conversions Conversion("bohr", "ang", lambda x: x * 0.52917721092) @@ -92,20 +178,20 @@ def 
get_conversion(in_unit: str, out_unit: str): Conversion("nm", "bohr", lambda x: x * 18.8973) Conversion("bohr", "nm", lambda x: x / 18.8973) -# common forces conversion -Conversion("hartree/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "ev")(x))) -Conversion("hartree/bohr", "ev/bohr", lambda x: get_conversion("hartree", "ev")(x)) -Conversion("hartree/bohr", "kcal/mol/bohr", lambda x: get_conversion("hartree", "kcal/mol")(x)) -Conversion( - "hartree/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "kcal/mol")(x)) -) -Conversion("hartree/ang", "kcal/mol/ang", lambda x: get_conversion("hartree", "kcal/mol")(x)) -Conversion("hartree/ang", "hartree/bohr", lambda x: get_conversion("bohr", "ang")(x)) -Conversion("hartree/bohr", "hartree/ang", lambda x: get_conversion("ang", "bohr")(x)) -Conversion("kcal/mol/bohr", "hartree/bohr", lambda x: get_conversion("kcal/mol", "hartree")(x)) -Conversion("ev/ang", "hartree/ang", lambda x: get_conversion("ev", "hartree")(x)) -Conversion("ev/bohr", "hartree/bohr", lambda x: get_conversion("ev", "hartree")(x)) -Conversion("ev/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(x)) -Conversion("ev/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("ev", "kcal/mol")(x))) -Conversion("kcal/mol/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(x)) -Conversion("ev/ang", "kcal/mol/ang", lambda x: get_conversion("ev", "kcal/mol")(x)) +## common forces conversion +#Conversion("hartree/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "ev")(x))) +#Conversion("hartree/bohr", "ev/bohr", lambda x: get_conversion("hartree", "ev")(x)) +#Conversion("hartree/bohr", "kcal/mol/bohr", lambda x: get_conversion("hartree", "kcal/mol")(x)) +#Conversion( +# "hartree/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "kcal/mol")(x)) +#) +#Conversion("hartree/ang", 
"kcal/mol/ang", lambda x: get_conversion("hartree", "kcal/mol")(x)) +#Conversion("hartree/ang", "hartree/bohr", lambda x: get_conversion("bohr", "ang")(x)) +#Conversion("hartree/bohr", "hartree/ang", lambda x: get_conversion("ang", "bohr")(x)) +#Conversion("kcal/mol/bohr", "hartree/bohr", lambda x: get_conversion("kcal/mol", "hartree")(x)) +#Conversion("ev/ang", "hartree/ang", lambda x: get_conversion("ev", "hartree")(x)) +#Conversion("ev/bohr", "hartree/bohr", lambda x: get_conversion("ev", "hartree")(x)) +#Conversion("ev/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(x)) +#Conversion("ev/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("ev", "kcal/mol")(x))) +#Conversion("kcal/mol/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(x)) +#Conversion("ev/ang", "kcal/mol/ang", lambda x: get_conversion("ev", "kcal/mol")(x)) diff --git a/tests/test_filedataset.py b/tests/test_filedataset.py index 8defc7f..27baf66 100644 --- a/tests/test_filedataset.py +++ b/tests/test_filedataset.py @@ -1,4 +1,5 @@ from io import StringIO +import os import numpy as np import pytest @@ -6,6 +7,7 @@ from openqdc.datasets.io import XYZDataset from openqdc.methods.enums import PotentialMethod from openqdc.utils.package_utils import has_package +from openqdc.utils.io import get_local_cache if has_package("torch"): import torch @@ -20,6 +22,13 @@ } +@pytest.fixture(autouse=True) +def clean_before_run(): + # start by removing any cached data + cache_dir = get_local_cache() + os.system(f"rm -rf {cache_dir}/XYZDataset") + yield + @pytest.fixture def xyz_filelike(): xyz_str = """3