Skip to content

Commit

Permalink
simplified component-wise-force stats calculation and bug-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Nikhil Shenoy committed Apr 5, 2024
1 parent ac299a5 commit 40d900d
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 66 deletions.
22 changes: 9 additions & 13 deletions openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
FormationEnergyStats,
PerAtomFormationEnergyStats,
StatisticManager,
StatisticsResults,
TotalEnergyStats,
)
from openqdc.utils.constants import MAX_CHARGE, NB_ATOMIC_FEATURES
Expand Down Expand Up @@ -528,25 +529,20 @@ def get_statistics(self, return_none: bool = True):
"ForcesCalculatorStats": {
"mean": np.array([0.0]),
"std": np.array([0.0]),
"components": {
"mean": np.array([[0.0], [0.0], [0.0]]),
"std": np.array([[0.0], [0.0], [0.0]]),
"rms": np.array([[0.0], [0.0], [0.0]]),
},
"component_mean": np.array([[0.0], [0.0], [0.0]]),
"component_std": np.array([[0.0], [0.0], [0.0]]),
"component_rms": np.array([[0.0], [0.0], [0.0]]),
}
}
)
# cycle trough dict to convert units
for key in selected_stats:
if key.lower() == str(ForcesCalculatorStats):
if isinstance(selected_stats[key], StatisticsResults):
selected_stats[key] = selected_stats[key].to_dict()

if key.lower() == ForcesCalculatorStats.__name__.lower():
for key2 in selected_stats[key]:
if key2 != "components":
selected_stats[key][key2] = self.convert_forces(selected_stats[key][key2])
else:
for key2 in selected_stats[key]["components"]:
selected_stats[key]["components"][key2] = self.convert_forces(
selected_stats[key]["components"][key2]
)
selected_stats[key][key2] = self.convert_forces(selected_stats[key][key2])
else:
for key2 in selected_stats[key]:
selected_stats[key][key2] = self.convert_energy(selected_stats[key][key2])
Expand Down
24 changes: 2 additions & 22 deletions openqdc/datasets/potential/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from openqdc.datasets.base import BaseDataset
from openqdc.methods import PotentialMethod
from openqdc.utils.constants import NOT_DEFINED


class Dummy(BaseDataset):
Expand All @@ -11,8 +10,8 @@ class Dummy(BaseDataset):
"""

__name__ = "dummy"
__energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6]
__force_mask__ = [False, True]
__energy_methods__ = [PotentialMethod.SVWN_DEF2_TZVP, PotentialMethod.PM6, PotentialMethod.GFN2_XTB]
__force_mask__ = [False, True, True]
__energy_unit__ = "kcal/mol"
__distance_unit__ = "ang"
__forces_unit__ = "kcal/mol/ang"
Expand All @@ -23,25 +22,6 @@ class Dummy(BaseDataset):
__isolated_atom_energies__ = []
__average_n_atoms__ = None

@property
def _stats(self):
return {
"formation": {
"energy": {
"mean": np.array([[-12.94348027, -9.83037297]]),
"std": np.array([[4.39971409, 3.3574188]]),
},
"forces": NOT_DEFINED,
},
"total": {
"energy": {
"mean": np.array([[-89.44242, -1740.5336]]),
"std": np.array([[29.599571, 791.48663]]),
},
"forces": NOT_DEFINED,
},
}

def _post_init(self, overwrite_local_cache, energy_unit, distance_unit) -> None:
self.setup_dummy()
return super()._post_init(overwrite_local_cache, energy_unit, distance_unit)
Expand Down
38 changes: 15 additions & 23 deletions openqdc/datasets/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,6 @@ class EnergyStatistics(StatisticsResults):
std: Optional[np.ndarray]


@dataclass
class ForceComponentsStatistics(StatisticsResults):
"""
Dataclass for force statistics related to the x,y,z components
mean,std,rms are supposed to be 2d arrays related to the x,y,z components
of the forces
"""

mean: Optional[np.ndarray]
std: Optional[np.ndarray]
rms: Optional[np.ndarray]


@dataclass
class ForceStatistics(StatisticsResults):
"""
Expand All @@ -57,7 +44,9 @@ class ForceStatistics(StatisticsResults):

mean: Optional[np.ndarray]
std: Optional[np.ndarray]
components: ForceComponentsStatistics
component_mean: Optional[np.ndarray]
component_std: Optional[np.ndarray]
component_rms: Optional[np.ndarray]


class StatisticManager:
Expand Down Expand Up @@ -266,17 +255,20 @@ class ForcesCalculatorStats(AbstractStatsCalculator):

def compute(self) -> ForceStatistics:
if not self.has_forces:
return ForceStatistics(
mean=None, std=None, components=ForceComponentsStatistics(rms=None, std=None, mean=None)
)
return ForceStatistics(mean=None, std=None, component_mean=None, component_std=None, component_rms=None)
converted_force_data = self.forces
force_mean = np.nanmean(converted_force_data, axis=0)
force_std = np.nanstd(converted_force_data, axis=0)
force_rms = np.sqrt(np.nanmean(converted_force_data**2, axis=0))
num_methods = converted_force_data.shape[2]
mean = np.nanmean(converted_force_data.reshape(-1, num_methods), axis=0)
std = np.nanstd(converted_force_data.reshape(-1, num_methods), axis=0)
component_mean = np.nanmean(converted_force_data, axis=0)
component_std = np.nanstd(converted_force_data, axis=0)
component_rms = np.sqrt(np.nanmean(converted_force_data**2, axis=0))
return ForceStatistics(
mean=np.atleast_2d(force_mean),
std=np.atleast_2d(force_std),
components=ForceComponentsStatistics(rms=force_rms, std=force_std, mean=force_mean),
mean=np.atleast_2d(mean),
std=np.atleast_2d(std),
component_mean=np.atleast_2d(component_mean),
component_std=np.atleast_2d(component_std),
component_rms=np.atleast_2d(component_rms),
)


Expand Down
8 changes: 3 additions & 5 deletions openqdc/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,9 @@
NOT_DEFINED = {
"mean": None,
"std": None,
"components": {
"mean": None,
"std": None,
"rms": None,
},
"component_mean": None,
"component_std": None,
"component_rms": None,
}

ATOM_TABLE = Chem.GetPeriodicTable()
Expand Down
56 changes: 54 additions & 2 deletions tests/test_dummy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
"""Path hack to make tests work."""

import os

import numpy as np
import pytest

from openqdc.datasets.potential.dummy import Dummy # noqa: E402
from openqdc.utils.io import get_local_cache
from openqdc.utils.package_utils import has_package

# start by removing any cached data
cache_dir = get_local_cache()
os.system(f"rm -rf {cache_dir}/dummy")


if has_package("torch"):
import torch

Expand All @@ -19,8 +27,12 @@
}


def test_dummy():
ds = Dummy()
@pytest.fixture
def ds():
return Dummy()


def test_dummy(ds):
assert len(ds) > 10
assert ds[100]

Expand Down Expand Up @@ -68,3 +80,43 @@ def custom_fn(bunch):

assert "new_key" in data
assert data["new_key"] == data["name"] + data["subset"]


def test_get_statistics(ds):
stats = ds.get_statistics()

keys = ["ForcesCalculatorStats", "FormationEnergyStats", "PerAtomFormationEnergyStats", "TotalEnergyStats"]
assert all(k in stats for k in keys)


def test_energy_statistics_shapes(ds):
stats = ds.get_statistics()

num_methods = len(ds.energy_methods)

formation_energy_stats = stats["FormationEnergyStats"]
assert formation_energy_stats["mean"].shape == (1, num_methods)
assert formation_energy_stats["std"].shape == (1, num_methods)

per_atom_formation_energy_stats = stats["PerAtomFormationEnergyStats"]
assert per_atom_formation_energy_stats["mean"].shape == (1, num_methods)
assert per_atom_formation_energy_stats["std"].shape == (1, num_methods)

total_energy_stats = stats["TotalEnergyStats"]
assert total_energy_stats["mean"].shape == (1, num_methods)
assert total_energy_stats["std"].shape == (1, num_methods)


def test_force_statistics_shapes(ds):
stats = ds.get_statistics()
num_force_methods = len(ds.force_methods)

forces_stats = stats["ForcesCalculatorStats"]
keys = ["mean", "std", "component_mean", "component_std", "component_rms"]
assert all(k in forces_stats for k in keys)

assert forces_stats["mean"].shape == (1, num_force_methods)
assert forces_stats["std"].shape == (1, num_force_methods)
assert forces_stats["component_mean"].shape == (3, num_force_methods)
assert forces_stats["component_std"].shape == (3, num_force_methods)
assert forces_stats["component_rms"].shape == (3, num_force_methods)
3 changes: 2 additions & 1 deletion tests/test_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def test_regressors(small_dummy):
setattr(reg, "solver_type", solver_type)
reg.solver = reg._get_solver()
assert isinstance(reg.solver, inst)
num_methods = len(small_dummy.energy_methods)
try:
results = reg.solve()
assert results[0].shape[1] == 2
assert results[0].shape[1] == num_methods
except np.linalg.LinAlgError:
pass

0 comments on commit 40d900d

Please sign in to comment.