Skip to content

Commit

Permalink
Fix on incorrect unit changing, stats calculated on original units, c…
Browse files Browse the repository at this point in the history
…onversion on the fly
  • Loading branch information
FNTwin committed Nov 29, 2023
1 parent 52f69ce commit f38bda1
Showing 1 changed file with 62 additions and 10 deletions.
72 changes: 62 additions & 10 deletions src/openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,34 @@ def __init__(
) -> None:
set_cache_dir(cache_dir)
self.data = None
self._set_units(energy_unit, distance_unit)
if not self.is_preprocessed():
raise DatasetNotAvailableError(self.__name__)
else:
self.read_preprocess(overwrite_local_cache=overwrite_local_cache)
self._post_init(overwrite_local_cache, energy_unit, distance_unit)

def _post_init(
self,
overwrite_local_cache: bool = False,
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
) -> None:
self._set_units(None, None)
self._set_isolated_atom_energies()
self._precompute_statistics(overwrite_local_cache=overwrite_local_cache)
self._set_units(energy_unit, distance_unit)
self._convert_data()
self._set_isolated_atom_energies()
self._precompute_statistics()

def _convert_data(self):
    """Convert every entry of ``self.data`` named in ``self.data_keys`` in place.

    First logs the units the data is being converted to (the force unit is
    reported as 'None' when ``self.__force_methods__`` is falsy), then
    replaces each ``self.data[key]`` with the value returned by
    ``self._convert_on_loading`` for that key.
    """
    # Adjacent f-string literals are joined at compile time, yielding one
    # multi-line message without relying on backslash line continuation.
    logger.info(
        f"Converting {self.__name__} data to the following units:\n"
        f"Energy: {self.energy_unit},\n"
        f"Distance: {self.distance_unit},\n"
        f"Forces: {self.force_unit if self.__force_methods__ else 'None'}"
    )
    for name in self.data_keys:
        raw = self.data[name]
        self.data[name] = self._convert_on_loading(raw, name)

def _precompute_statistics(self, overwrite_local_cache: bool = False):
local_path = p_join(self.preprocess_path, "stats.pkl")
Expand Down Expand Up @@ -153,7 +174,12 @@ def _precompute_E(self):
total_E_mean = np.nanmean(converted_energy_data, axis=0)
total_E_std = np.nanstd(converted_energy_data, axis=0)

return formation_E_mean, formation_E_std, total_E_mean, total_E_std
return (
np.atleast_2d(formation_E_mean),
np.atleast_2d(formation_E_std),
np.atleast_2d(total_E_mean),
np.atleast_2d(total_E_std),
)

def _precompute_F(self):
if len(self.__force_methods__) == 0:
Expand All @@ -163,16 +189,16 @@ def _precompute_F(self):
force_std = np.nanstd(converted_force_data, axis=0)
force_rms = np.sqrt(np.nanmean(converted_force_data**2, axis=0))
return {
"mean": force_mean,
"std": force_std,
"components": {"rms": force_rms, "std": force_std.mean(axis=0), "mean": force_mean.mean(axis=0)},
"mean": np.atleast_2d(force_mean.mean(axis=0)),
"std": np.atleast_2d(force_std.mean(axis=0)),
"components": {"rms": force_rms, "std": force_std, "mean": force_mean},
}

@property
def numbers(self):
if hasattr(self, "_numbers"):
return self._numbers
self._numbers = np.unique(self.data["atomic_inputs"][..., 0]).astype(np.int32)
self._numbers = pd.unique(self.data["atomic_inputs"][..., 0]).astype(np.int32)
return self._numbers

@property
Expand Down Expand Up @@ -302,6 +328,18 @@ def save_preprocess(self, data_dict):
pkl.dump(data_dict, f)
push_remote(local_path, overwrite=True)

def _convert_on_loading(self, x, key):
    """Return ``x`` converted according to which data entry ``key`` names.

    - ``"energies"``: passed through ``self.convert_energy``.
    - ``"forces"``: passed through ``self.convert_forces``.
    - ``"atomic_inputs"``: copied into a float32 array whose last three
      columns are replaced by ``self.convert_distance`` of those columns;
      the remaining columns keep their values (cast to float32).
    - any other key: ``x`` is returned untouched.
    """
    if key == "energies":
        return self.convert_energy(x)
    if key == "forces":
        return self.convert_forces(x)
    if key != "atomic_inputs":
        # Nothing unit-dependent in this entry.
        return x

    # np.array makes a fresh float32 copy, so the caller's object is not
    # modified by the slice assignment below.
    converted = np.array(x, dtype=np.float32)
    tail = converted[:, -3:]
    converted[:, -3:] = self.convert_distance(tail)
    return converted

def read_preprocess(self, overwrite_local_cache=False):
logger.info("Reading preprocessed data")
logger.info(
Expand Down Expand Up @@ -473,14 +511,14 @@ def __getitem__(self, idx: int):
z, c, positions, energies = (
np.array(input[:, 0], dtype=np.int32),
np.array(input[:, 1], dtype=np.int32),
self.convert_distance(np.array(input[:, -3:], dtype=np.float32)),
self.convert_energy(np.array(self.data["energies"][idx], dtype=np.float32)),
np.array(input[:, -3:], dtype=np.float32),
np.array(self.data["energies"][idx], dtype=np.float32),
)
name = self.__smiles_converter__(self.data["name"][idx])
subset = self.data["subset"][idx]

if "forces" in self.data:
forces = self.convert_forces(np.array(self.data["forces"][p_start:p_end], dtype=np.float32))
forces = np.array(self.data["forces"][p_start:p_end], dtype=np.float32)
else:
forces = None
return Bunch(
Expand Down Expand Up @@ -542,4 +580,18 @@ def get_statistics(self, normalization: str = "formation", return_none: bool = T
}
}
)
# cycle through dict to convert units
for key in selected_stats:
if key == "forces":
for key2 in selected_stats[key]:
if key2 != "components":
selected_stats[key][key2] = self.convert_forces(selected_stats[key][key2])
else:
for key2 in selected_stats[key]["components"]:
selected_stats[key]["components"][key2] = self.convert_forces(
selected_stats[key]["components"][key2]
)
else:
for key2 in selected_stats[key]:
selected_stats[key][key2] = self.convert_energy(selected_stats[key][key2])
return selected_stats

0 comments on commit f38bda1

Please sign in to comment.