Skip to content

Commit

Permalink
Units enum, fixes, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FNTwin committed Jun 21, 2024
1 parent b3c3b02 commit 462cdbe
Show file tree
Hide file tree
Showing 4 changed files with 528 additions and 10 deletions.
12 changes: 6 additions & 6 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def distance_unit(self):

@property
def force_unit(self):
return ForceTypeConversion(*self.__forces_unit__.split("/"))
units = self.__forces_unit__.split("/")
if len(units) > 2:
units = ["/".join(units[:2]), units[-1]]
return ForceTypeConversion(*units)

@property
def root(self):
Expand Down Expand Up @@ -296,17 +299,14 @@ def data_shapes(self):
"forces": (-1, 3, len(self.force_methods)),
}

def _set_units(self, en, ds):
def _set_units(self, en: Optional[str] = None, ds: Optional[str] = None):
old_en, old_ds = self.energy_unit, self.distance_unit
en = en if en is not None else old_en
ds = ds if ds is not None else old_ds
self.set_energy_unit(en)
self.set_distance_unit(ds)
if self.__force_methods__:
# self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit)
self._fn_forces = self.force_unit.to(
str(self.energy_unit), str(self.distance_unit)
) # get_conversion(old_en + "/" + old_ds, self.__forces_unit__)
self._fn_forces = self.force_unit.to(str(self.energy_unit), str(self.distance_unit))
self.__forces_unit__ = str(self.energy_unit) + "/" + str(self.distance_unit)

def _set_isolated_atom_energies(self):
Expand Down
2 changes: 1 addition & 1 deletion openqdc/datasets/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(
self.recompute_statistics = True
self.refit_e0s = True
self.energy_type = energy_type
self._original_unit = energy_unit
self.__energy_unit__ = energy_unit
self._original_unit = self.energy_unit
self.__distance_unit__ = distance_unit
self.__energy_methods__ = [PotentialMethod.NONE if not level_of_theory else level_of_theory]
self.energy_target_names = ["xyz"]
Expand Down
4 changes: 1 addition & 3 deletions openqdc/utils/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ def __str__(self):

# Parent class for all conversion enums
class ConversionEnum(Enum):
@classmethod
def list(cls):
return list(map(lambda c: c.value, cls))
pass


@unique
Expand Down
Loading

0 comments on commit 462cdbe

Please sign in to comment.