Skip to content

Commit

Permalink
better distance validation
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Jun 3, 2024
1 parent 6627774 commit f311bc9
Show file tree
Hide file tree
Showing 6 changed files with 689 additions and 32 deletions.
15 changes: 2 additions & 13 deletions mellon/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
_validate_bool,
_validate_array,
_validate_float_or_iterable_numerical,
_validate_nn_distances,
)
from .parameter_validation import (
_validate_params,
Expand Down Expand Up @@ -240,19 +241,7 @@ def _compute_nn_distances(self):
x = self.x
logger.info("Computing nearest neighbor distances.")
nn_distances = compute_nn_distances(x)

# Check for invalid values
nan_count = np.isnan(nn_distances).sum()
inf_count = np.isinf(nn_distances).sum()
negative_count = (nn_distances < 0).sum()
if nan_count > 0 or inf_count > 0 or negative_count > 0:
total_invalid = nan_count + inf_count + negative_count
logger.warning(
"The computed nearest neighbor distances (`nn_distances` attribute) contain "
f"{total_invalid:,} invalid values: "
f"{nan_count:,} NaN, {inf_count:,} infinite, {negative_count:,} less than 0. "
"Please check the input data."
)
nn_distances = _validate_nn_distances(nn_distances)
return nn_distances

def _compute_ls(self):
Expand Down
Loading

0 comments on commit f311bc9

Please sign in to comment.