Skip to content

Commit

Permalink
Added `_remove_outliers` arguments (`remove_outliers`, `avg_fn`, `num_stds`) to `BaseDataset.__init__`
Browse files Browse the repository at this point in the history
  • Loading branch information
mcneela committed Feb 22, 2024
1 parent 8ed9a00 commit bc4c747
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/openqdc/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,27 @@ class BaseDataset:
__average_nb_atoms__ = None
__stats__ = {}

# Lookup table from averaging-strategy name to the NumPy reduction that
# implements it; `_remove_outliers` indexes it with `avg_fn` to obtain the
# centre value around which the +/- `num_stds` outlier window is built.
avg_options = dict(median=np.median, mean=np.mean)

def __init__(
self,
energy_unit: Optional[str] = None,
distance_unit: Optional[str] = None,
overwrite_local_cache: bool = False,
cache_dir: Optional[str] = None,
remove_outliers: bool = True,
avg_fn = "median",
num_stds: float = 3.0,
) -> None:
set_cache_dir(cache_dir)
self.data = None
self.remove_outliers = remove_outliers
self.avg_fn = avg_fn
self.num_stds = num_stds

if not self.is_preprocessed():
raise DatasetNotAvailableError(self.__name__)
else:
Expand Down Expand Up @@ -169,18 +179,19 @@ def _compute_average_nb_atoms(self):
def _remove_outliers(
self,
formation_E: np.array,
mean_or_median: str = "median",
avg_fn: str = "median",
num_stds: float = 3.0,
) -> np.array:
assert(
mean_or_median in ["mean", "median"],
f"{mean_or_median} is not a valid option, should be one of ['mean', 'median']"
avg_fn in BaseDataset.avg_options.keys(),
f"{avg_fn} is not a valid option, should be one of {list(BaseDataset.avg_options.keys())}"
)
logger.info(f"Removing outliers outside {mean_or_median} +/- {num_stds} stds")
fn = np.median if mean_or_median == "median" else np.mean
logger.info(f"Removing outliers outside {avg_fn} +/- {num_stds} stds")
fn = BaseDataset.avg_options[avg_fn]
mid = fn(formation_E)
mask = np.logical_or(formation_E < mid - num_stds * formation_E.std(), formation_E > mid + num_stds * formation_E.std())
formation_E = formation_E[~mask] # TODO: Christian, your formation E values are different than the ones I calculated yesterday, not sure why?
print(self.data.keys())
for key in self.data:
# TODO: We need a way to map the mask to the 'atomic_inputs' array
if key != "atomic_inputs":
Expand All @@ -203,7 +214,9 @@ def _precompute_E(self):

# remove outliers if requested in __init__
if self.remove_outliers:
E = self._remove_outliers(np.squeeze(E.T))
E = self._remove_outliers(np.squeeze(E.T),
avg_fn=self.avg_fn,
num_stds=self.num_stds)

inter_E_mean = np.nanmean(E / self.data["n_atoms"][:, None], axis=0)
inter_E_std = np.nanstd(E / self.data["n_atoms"][:, None], axis=0)
Expand Down

0 comments on commit bc4c747

Please sign in to comment.