Skip to content

Commit

Permalink
Black . , isort .
Browse files Browse the repository at this point in the history
  • Loading branch information
FNTwin committed Jun 21, 2024
1 parent 2114551 commit b3c3b02
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 73 deletions.
31 changes: 19 additions & 12 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@
)
from openqdc.utils.package_utils import has_package, requires_package
from openqdc.utils.regressor import Regressor # noqa
from openqdc.utils.units import get_conversion, EnergyTypeConversion, DistanceTypeConversion, ForceTypeConversion
from openqdc.utils.units import (
DistanceTypeConversion,
EnergyTypeConversion,
ForceTypeConversion,
get_conversion,
)

if has_package("torch"):
import torch
Expand Down Expand Up @@ -298,10 +303,12 @@ def _set_units(self, en, ds):
self.set_energy_unit(en)
self.set_distance_unit(ds)
if self.__force_methods__:
#self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit)
self._fn_forces = self.force_unit.to(str(self.energy_unit),str(self.distance_unit)) #get_conversion(old_en + "/" + old_ds, self.__forces_unit__)
self.__forces_unit__ = str(self.energy_unit)+"/"+str(self.distance_unit)

# self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit)
self._fn_forces = self.force_unit.to(
str(self.energy_unit), str(self.distance_unit)
) # get_conversion(old_en + "/" + old_ds, self.__forces_unit__)
self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit)

def _set_isolated_atom_energies(self):
if self.__energy_methods__ is None:
logger.error("No energy methods defined for this dataset.")
Expand All @@ -325,20 +332,20 @@ def set_energy_unit(self, value: str):
"""
Set a new energy unit for the dataset.
"""
#old_unit = self.energy_unit
#self.__energy_unit__ = value
self._fn_energy = self.energy_unit.to(value) #get_conversion(old_unit, value)
# old_unit = self.energy_unit
# self.__energy_unit__ = value
self._fn_energy = self.energy_unit.to(value) # get_conversion(old_unit, value)
self.__energy_unit__ = value

def set_distance_unit(self, value: str):
"""
Set a new distance unit for the dataset.
"""
#old_unit = self.distance_unit
# self.__distance_unit__ = value
self._fn_distance = self.distance_unit.to(value) #get_conversion(old_unit, value)
# old_unit = self.distance_unit
# self.__distance_unit__ = value
self._fn_distance = self.distance_unit.to(value) # get_conversion(old_unit, value)
self.__distance_unit__ = value

def set_array_format(self, format: str):
assert format in ["numpy", "torch", "jax"], f"Format {format} not supported."
self.array_format = format
Expand Down
108 changes: 50 additions & 58 deletions openqdc/utils/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,79 +8,93 @@
["ang", "nm", "bohr"]
"""

from typing import Callable
from enum import Enum, unique
from typing import Callable

from openqdc.utils.exceptions import ConversionAlreadyDefined, ConversionNotDefinedError

CONVERSION_REGISTRY = {}


# Redefined to avoid circular imports
class StrEnum(str, Enum):
def __str__(self):
return self.value.lower()


# Parent class for all conversion enums
class ConversionEnum(Enum):

@classmethod
def list(cls):
return list(map(lambda c: c.value, cls))


@unique
class EnergyTypeConversion(ConversionEnum, StrEnum):
"""
Define the possible energy units for conversion
"""

KCAL_MOL = "kcal/mol"
KJ_MOL = "kj/mol"
HARTREE = "hartree"
EV = "ev"
MEV = "mev"
RYD = "ryd"



def to(self, energy: "EnergyTypeConversion"):
return get_conversion(str(self), str(energy))


@unique
class DistanceTypeConversion(ConversionEnum, StrEnum):
"""
Define the possible distance units for conversion
"""

ANG = "ang"
NM = "nm"
BOHR = "bohr"
def to(self, distance: "DistanceTypeConversion", fraction : bool = False):

def to(self, distance: "DistanceTypeConversion", fraction: bool = False):
return get_conversion(str(self), str(distance)) if not fraction else get_conversion(str(distance), str(self))

@unique


@unique
class ForceTypeConversion(ConversionEnum):
"""
Define the possible foce units for conversion
"""

# Name = EnergyTypeConversion, , DistanceTypeConversion
HARTREE_BOHR = EnergyTypeConversion.HARTREE , DistanceTypeConversion.BOHR
HARTREE_ANG = EnergyTypeConversion.HARTREE , DistanceTypeConversion.ANG
HARTREE_NM = EnergyTypeConversion.HARTREE , DistanceTypeConversion.NM
EV_BOHR = EnergyTypeConversion.EV , DistanceTypeConversion.BOHR
EV_ANG = EnergyTypeConversion.EV , DistanceTypeConversion.ANG
EV_NM = EnergyTypeConversion.EV , DistanceTypeConversion.NM
KCAL_MOL_BOHR = EnergyTypeConversion.KCAL_MOL , DistanceTypeConversion.BOHR
KCAL_MOL_ANG = EnergyTypeConversion.KCAL_MOL , DistanceTypeConversion.ANG
KCAL_MOL_NM = EnergyTypeConversion.KCAL_MOL , DistanceTypeConversion.NM
KJ_MOL_BOHR = EnergyTypeConversion.KJ_MOL , DistanceTypeConversion.BOHR
KJ_MOL_ANG = EnergyTypeConversion.KJ_MOL , DistanceTypeConversion.ANG
KJ_MOL_NM = EnergyTypeConversion.KJ_MOL , DistanceTypeConversion.NM
MEV_BOHR = EnergyTypeConversion.MEV , DistanceTypeConversion.BOHR
MEV_ANG = EnergyTypeConversion.MEV , DistanceTypeConversion.ANG
MEV_NM = EnergyTypeConversion.MEV , DistanceTypeConversion.NM
RYD_BOHR = EnergyTypeConversion.RYD , DistanceTypeConversion.BOHR
RYD_ANG = EnergyTypeConversion.RYD , DistanceTypeConversion.ANG
RYD_NM = EnergyTypeConversion.RYD , DistanceTypeConversion.NM

def __init__(self,
energy: EnergyTypeConversion,
distance: DistanceTypeConversion):
HARTREE_BOHR = EnergyTypeConversion.HARTREE, DistanceTypeConversion.BOHR
HARTREE_ANG = EnergyTypeConversion.HARTREE, DistanceTypeConversion.ANG
HARTREE_NM = EnergyTypeConversion.HARTREE, DistanceTypeConversion.NM
EV_BOHR = EnergyTypeConversion.EV, DistanceTypeConversion.BOHR
EV_ANG = EnergyTypeConversion.EV, DistanceTypeConversion.ANG
EV_NM = EnergyTypeConversion.EV, DistanceTypeConversion.NM
KCAL_MOL_BOHR = EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.BOHR
KCAL_MOL_ANG = EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.ANG
KCAL_MOL_NM = EnergyTypeConversion.KCAL_MOL, DistanceTypeConversion.NM
KJ_MOL_BOHR = EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.BOHR
KJ_MOL_ANG = EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.ANG
KJ_MOL_NM = EnergyTypeConversion.KJ_MOL, DistanceTypeConversion.NM
MEV_BOHR = EnergyTypeConversion.MEV, DistanceTypeConversion.BOHR
MEV_ANG = EnergyTypeConversion.MEV, DistanceTypeConversion.ANG
MEV_NM = EnergyTypeConversion.MEV, DistanceTypeConversion.NM
RYD_BOHR = EnergyTypeConversion.RYD, DistanceTypeConversion.BOHR
RYD_ANG = EnergyTypeConversion.RYD, DistanceTypeConversion.ANG
RYD_NM = EnergyTypeConversion.RYD, DistanceTypeConversion.NM

def __init__(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion):
self.energy = energy
self.distance = distance

def __str__(self):
return f"{self.energy}/{self.distance}"

def to(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion):
return lambda x : self.distance.to(distance, fraction=True)(self.energy.to(energy)(x))


CONVERSION_REGISTRY = {}
def to(self, energy: EnergyTypeConversion, distance: DistanceTypeConversion):
return lambda x: self.distance.to(distance, fraction=True)(self.energy.to(energy)(x))


class Conversion:
Expand All @@ -95,11 +109,7 @@ class Conversion:
The callable to compute the conversion
"""

def __init__(self,
in_unit: str,
out_unit: str,
func: Callable[[float], float]
):
def __init__(self, in_unit: str, out_unit: str, func: Callable[[float], float]):
"""
Parameters
Expand Down Expand Up @@ -177,21 +187,3 @@ def get_conversion(in_unit: str, out_unit: str):
Conversion("nm", "ang", lambda x: x * 10.0)
Conversion("nm", "bohr", lambda x: x * 18.8973)
Conversion("bohr", "nm", lambda x: x / 18.8973)

## common forces conversion
#Conversion("hartree/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "ev")(x)))
#Conversion("hartree/bohr", "ev/bohr", lambda x: get_conversion("hartree", "ev")(x))
#Conversion("hartree/bohr", "kcal/mol/bohr", lambda x: get_conversion("hartree", "kcal/mol")(x))
#Conversion(
# "hartree/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("hartree", "kcal/mol")(x))
#)
#Conversion("hartree/ang", "kcal/mol/ang", lambda x: get_conversion("hartree", "kcal/mol")(x))
#Conversion("hartree/ang", "hartree/bohr", lambda x: get_conversion("bohr", "ang")(x))
#Conversion("hartree/bohr", "hartree/ang", lambda x: get_conversion("ang", "bohr")(x))
#Conversion("kcal/mol/bohr", "hartree/bohr", lambda x: get_conversion("kcal/mol", "hartree")(x))
#Conversion("ev/ang", "hartree/ang", lambda x: get_conversion("ev", "hartree")(x))
#Conversion("ev/bohr", "hartree/bohr", lambda x: get_conversion("ev", "hartree")(x))
#Conversion("ev/bohr", "ev/ang", lambda x: get_conversion("ang", "bohr")(x))
#Conversion("ev/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(get_conversion("ev", "kcal/mol")(x)))
#Conversion("kcal/mol/bohr", "kcal/mol/ang", lambda x: get_conversion("ang", "bohr")(x))
#Conversion("ev/ang", "kcal/mol/ang", lambda x: get_conversion("ev", "kcal/mol")(x))
7 changes: 4 additions & 3 deletions tests/test_filedataset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from io import StringIO
import os
from io import StringIO

import numpy as np
import pytest

from openqdc.datasets.io import XYZDataset
from openqdc.methods.enums import PotentialMethod
from openqdc.utils.package_utils import has_package
from openqdc.utils.io import get_local_cache
from openqdc.utils.package_utils import has_package

if has_package("torch"):
import torch
Expand All @@ -28,7 +28,8 @@ def clean_before_run():
cache_dir = get_local_cache()
os.system(f"rm -rf {cache_dir}/XYZDataset")
yield



@pytest.fixture
def xyz_filelike():
xyz_str = """3
Expand Down

0 comments on commit b3c3b02

Please sign in to comment.