From bc4c747669a2f66f2391cd0d89028356e4852d2d Mon Sep 17 00:00:00 2001 From: Danny McNeela Date: Thu, 22 Feb 2024 15:36:58 +0000 Subject: [PATCH] added _remove_outliers args to __init__ --- src/openqdc/datasets/base.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/openqdc/datasets/base.py b/src/openqdc/datasets/base.py index 458da0b..4bbe705 100644 --- a/src/openqdc/datasets/base.py +++ b/src/openqdc/datasets/base.py @@ -97,6 +97,11 @@ class BaseDataset: __average_nb_atoms__ = None __stats__ = {} + avg_options = { + "median": np.median, + "mean": np.mean, + } # options for removing outliers + def __init__( self, energy_unit: Optional[str] = None, @@ -104,10 +109,15 @@ def __init__( overwrite_local_cache: bool = False, cache_dir: Optional[str] = None, remove_outliers: bool = True, + avg_fn = "median", + num_stds: float = 3.0, ) -> None: set_cache_dir(cache_dir) self.data = None self.remove_outliers = remove_outliers + self.avg_fn = avg_fn + self.num_stds = num_stds + if not self.is_preprocessed(): raise DatasetNotAvailableError(self.__name__) else: @@ -169,18 +179,19 @@ def _compute_average_nb_atoms(self): def _remove_outliers( self, formation_E: np.array, - mean_or_median: str = "median", + avg_fn: str = "median", num_stds: float = 3.0, ) -> np.array: assert( - mean_or_median in ["mean", "median"], - f"{mean_or_median} is not a valid option, should be one of ['mean', 'median']" + avg_fn in BaseDataset.avg_options.keys(), + f"{avg_fn} is not a valid option, should be one of {list(BaseDataset.avg_options.keys())}" ) - logger.info(f"Removing outliers outside {mean_or_median} +/- {num_stds} stds") - fn = np.median if mean_or_median == "median" else np.mean + logger.info(f"Removing outliers outside {avg_fn} +/- {num_stds} stds") + fn = BaseDataset.avg_options[avg_fn] mid = fn(formation_E) mask = np.logical_or(formation_E < mid - num_stds * formation_E.std(), formation_E > mid + num_stds * formation_E.std()) formation_E = formation_E[~mask] # TODO: Christian, your formation E values are different than the ones I calculated yesterday, not sure why? + print(self.data.keys()) for key in self.data: # TODO: We need a way to map the mask to the 'atomic_inputs' array if key != "atomic_inputs": @@ -203,7 +214,9 @@ def _precompute_E(self): # remove outliers if requested in __init__ if self.remove_outliers: - E = self._remove_outliers(np.squeeze(E.T)) + E = self._remove_outliers(np.squeeze(E.T), + avg_fn=self.avg_fn, + num_stds=self.num_stds) inter_E_mean = np.nanmean(E / self.data["n_atoms"][:, None], axis=0) inter_E_std = np.nanstd(E / self.data["n_atoms"][:, None], axis=0)