From 659b900e0e160e4eefd67d8d194882e24726e568 Mon Sep 17 00:00:00 2001 From: Dominik Date: Mon, 3 Jun 2024 17:26:35 -0700 Subject: [PATCH] Rename all `_validate` methods to `validate` --- CHANGELOG.md | 1 + mellon/base_model.py | 74 +++---- mellon/base_predictor.py | 46 ++--- mellon/compute_ls_time.py | 4 +- mellon/density_estimator.py | 12 +- mellon/derivatives.py | 6 +- mellon/dimensionality_estimator.py | 18 +- mellon/function_estimator.py | 24 +-- mellon/parameter_validation.py | 28 +-- mellon/parameters.py | 42 ++-- mellon/time_sensitive_density_estimator.py | 26 +-- mellon/util.py | 4 +- mellon/validation.py | 26 +-- tests/test_validation.py | 218 ++++++++++----------- 14 files changed, 265 insertions(+), 264 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1dd039..c4d8e7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ - Detailed logging about invalid `nn_distances`. - Validation of `nn_distances` passed at initialization - Validate that passed scalars are not `nan`. + - Rename all `_validate` methods to `validate` # v1.4.2 diff --git a/mellon/base_model.py b/mellon/base_model.py index 19be382..cdf5697 100644 --- a/mellon/base_model.py +++ b/mellon/base_model.py @@ -27,20 +27,20 @@ GaussianProcessType, ) from .validation import ( - _validate_positive_int, - _validate_positive_float, - _validate_float_or_int, - _validate_float, - _validate_string, - _validate_bool, - _validate_array, - _validate_float_or_iterable_numerical, - _validate_nn_distances, + validate_positive_int, + validate_positive_float, + validate_float_or_int, + validate_float, + validate_string, + validate_bool, + validate_array, + validate_float_or_iterable_numerical, + validate_nn_distances, ) from .parameter_validation import ( - _validate_params, - _validate_cov_func_curry, - _validate_cov_func, + validate_params, + validate_cov_func_curry, + validate_cov_func, ) @@ -80,42 +80,42 @@ def __init__( jit=DEFAULT_JIT, check_rank=None, ): - self.cov_func_curry = _validate_cov_func_curry( + self.cov_func_curry = validate_cov_func_curry( cov_func_curry, cov_func, "cov_func_curry" ) - self.n_landmarks = _validate_positive_int( + self.n_landmarks = validate_positive_int( n_landmarks, "n_landmarks", optional=True ) - self.rank = _validate_float_or_int(rank, "rank", optional=True) - self.jitter = _validate_positive_float(jitter, "jitter") - self.landmarks = _validate_array(landmarks, "landmarks", optional=True) + self.rank = validate_float_or_int(rank, "rank", optional=True) + self.jitter = validate_positive_float(jitter, "jitter") + self.landmarks = validate_array(landmarks, "landmarks", optional=True) self.gp_type = GaussianProcessType.from_string(gp_type, optional=True) - self.nn_distances = _validate_array(nn_distances, "nn_distances", optional=True) - self.nn_distances = _validate_nn_distances(self.nn_distances, optional=True) - self.mu = _validate_float(mu, "mu", optional=True) - self.ls = _validate_positive_float(ls, "ls", optional=True) - self.ls_factor = _validate_positive_float(ls_factor, "ls_factor") - self.cov_func = _validate_cov_func(cov_func, "cov_func", optional=True) - self.Lp = _validate_array(Lp, "Lp", optional=True) - self.L = _validate_array(L, "L", optional=True) - self.d = _validate_float_or_iterable_numerical( + self.nn_distances = validate_array(nn_distances, "nn_distances", optional=True) + self.nn_distances = validate_nn_distances(self.nn_distances, optional=True) + self.mu = validate_float(mu, "mu", optional=True) + self.ls = validate_positive_float(ls, "ls", optional=True) 
+ self.ls_factor = validate_positive_float(ls_factor, "ls_factor") + self.cov_func = validate_cov_func(cov_func, "cov_func", optional=True) + self.Lp = validate_array(Lp, "Lp", optional=True) + self.L = validate_array(L, "L", optional=True) + self.d = validate_float_or_iterable_numerical( d, "d", optional=True, positive=True ) - self.initial_value = _validate_array( + self.initial_value = validate_array( initial_value, "initial_value", optional=True ) - self.optimizer = _validate_string( + self.optimizer = validate_string( optimizer, "optimizer", choices={"adam", "advi", "L-BFGS-B"} ) - self.n_iter = _validate_positive_int(n_iter, "n_iter") - self.init_learn_rate = _validate_positive_float( + self.n_iter = validate_positive_int(n_iter, "n_iter") + self.init_learn_rate = validate_positive_float( init_learn_rate, "init_learn_rate" ) - self.predictor_with_uncertainty = _validate_bool( + self.predictor_with_uncertainty = validate_bool( predictor_with_uncertainty, "predictor_with_uncertainty" ) - self.jit = _validate_bool(jit, "jit") - self.check_rank = _validate_bool(check_rank, "check_rank", optional=True) + self.jit = validate_bool(jit, "jit") + self.check_rank = validate_bool(check_rank, "check_rank", optional=True) self.x = None self.pre_transformation = None @@ -202,7 +202,7 @@ def set_x(self, x): raise error if x is None: x = self.x - self.x = _validate_array(x, "x") + self.x = validate_array(x, "x") return self.x def _compute_n_landmarks(self): @@ -242,7 +242,7 @@ def _compute_nn_distances(self): x = self.x logger.info("Computing nearest neighbor distances.") nn_distances = compute_nn_distances(x) - nn_distances = _validate_nn_distances(nn_distances) + nn_distances = validate_nn_distances(nn_distances) return nn_distances def _compute_ls(self): @@ -341,7 +341,7 @@ def _compute_L(self): return L - def _validate_parameter(self): + def validate_parameter(self): """ Make sure there are no contradictions in the parameter settings. """ @@ -350,7 +350,7 @@ def _validate_parameter(self): n_samples = self.x.shape[0] n_landmarks = self.n_landmarks landmarks = self.landmarks - _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) + validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) def _run_inference(self): function = self.loss_func diff --git a/mellon/base_predictor.py b/mellon/base_predictor.py index f72604f..3e6c5a8 100644 --- a/mellon/base_predictor.py +++ b/mellon/base_predictor.py @@ -27,10 +27,10 @@ hessian_log_determinant, ) from .validation import ( - _validate_time_x, - _validate_float, - _validate_array, - _validate_bool, + validate_time_x, + validate_float, + validate_array, + validate_bool, ) @@ -164,9 +164,9 @@ def mean(self, x, normalize=False): If the number of features in 'x' does not match the number of features the predictor was trained on. """ - x = _validate_array(x, "x") + x = validate_array(x, "x") x = ensure_2d(x) - normalize = _validate_bool(normalize, "normalize") + normalize = validate_bool(normalize, "normalize") if x.shape[1] != self.n_input_features: raise ValueError( @@ -214,7 +214,7 @@ def covariance(self, x, diag=True): cov : array-like, shape (n_samples, n_samples) If diag=False, returns the full covariance matrix between samples. 
""" - x = _validate_array(x, "x") + x = validate_array(x, "x") x = ensure_2d(x) if x.shape[1] != self.n_input_features: raise ValueError( @@ -247,7 +247,7 @@ def mean_covariance(self, x, diag=True): cov : array-like, shape (n_samples, n_samples) If diag=False, returns the full covariance matrix between samples. """ - x = _validate_array(x, "x") + x = validate_array(x, "x") x = ensure_2d(x) if x.shape[1] != self.n_input_features: raise ValueError( @@ -278,7 +278,7 @@ def uncertainty(self, x, diag=True): cov : array-like, shape (n_samples, n_samples) if diag=False The full covariance matrix between the samples in the new data points. """ - x = _validate_array(x, "x") + x = validate_array(x, "x") x = ensure_2d(x) if x.shape[1] != self.n_input_features: raise ValueError( @@ -309,7 +309,7 @@ def gradient(self, x, jit=True): gradients.shape == x.shape :rtype: array-like """ - x = _validate_array(x, "x") + x = validate_array(x, "x") x = ensure_2d(x) return gradient(self._mean, x, jit=jit) @@ -326,7 +326,7 @@ def hessian(self, x, jit=True): hessians.shape == X.shape + X.shape[1:] :rtype: array-like """ - x = _validate_array(x, "x") + x = validate_array(x, "x") x = ensure_2d(x) return hessian(self.__call__, x, jit=jit) @@ -344,7 +344,7 @@ def hessian_log_determinant(self, x, jit=True): signs.shape == log_determinants.shape == x.shape[0] :rtype: array-like, array-like """ - x = _validate_array(x, "x") + x = validate_array(x, "x") x = ensure_2d(x) return hessian_log_determinant(self.__call__, x, jit=jit) @@ -574,8 +574,8 @@ def mean(self, x, logscale=False): If the number of features in 'x' does not match the number of features the predictor was trained on. """ - x = _validate_array(x, "x") - logscale = _validate_bool(logscale, "logscale") + x = validate_array(x, "x") + logscale = validate_bool(logscale, "logscale") x = ensure_2d(x) if x.shape[1] != self.n_input_features: raise ValueError( @@ -707,10 +707,10 @@ def mean(self, Xnew, time=None, normalize=False): If 'time' is an array and its size does not match 'Xnew'. """ # if time is a scalar, convert it into a 1D array of the same size as Xnew - Xnew = _validate_time_x( + Xnew = validate_time_x( Xnew, time, n_features=self.n_input_features, cast_scalar=True ) - normalize = _validate_bool(normalize, "normalize") + normalize = validate_bool(normalize, "normalize") if normalize: if self.n_obs is None or self.n_obs == 0: @@ -756,7 +756,7 @@ def covariance(self, Xnew, time=None, diag=True): If diag=False, returns the full covariance matrix between samples. """ # if time is a scalar, convert it into a 1D array of the same size as Xnew - Xnew = _validate_time_x( + Xnew = validate_time_x( Xnew, time, n_features=self.n_input_features, cast_scalar=True ) return self._covariance(Xnew, diag=diag) @@ -788,7 +788,7 @@ def mean_covariance(self, Xnew, time=None, diag=True): If diag=False, returns the full covariance matrix between samples. """ # if time is a scalar, convert it into a 1D array of the same size as Xnew - Xnew = _validate_time_x( + Xnew = validate_time_x( Xnew, time, n_features=self.n_input_features, cast_scalar=True ) return self._mean_covariance(Xnew, diag=diag) @@ -821,7 +821,7 @@ def uncertainty(self, Xnew, time=None, diag=True): cov : array-like, shape (n_samples, n_samples) if diag=False The full covariance matrix between the samples in the new data points. 
""" - Xnew = _validate_time_x( + Xnew = validate_time_x( Xnew, time, n_features=self.n_input_features, cast_scalar=True ) return self._covariance(Xnew, diag=diag) + self._mean_covariance( @@ -865,7 +865,7 @@ def time_derivative( The shape of the output array is the same as `x`. """ - Xnew = _validate_time_x( + Xnew = validate_time_x( x, time, n_features=self.n_input_features, cast_scalar=True ) return super().gradient(Xnew, jit=jit)[:, -1] @@ -897,7 +897,7 @@ def gradient(self, x, time, jit=True): The gradient of the prediction function at each point in `x`. The shape of the output array is the same as `x`. """ - Xnew = _validate_time_x( + Xnew = validate_time_x( x, time, n_features=self.n_input_features, cast_scalar=True ) X, time = Xnew[:, :-1], Xnew[:, -1] @@ -932,7 +932,7 @@ def hessian(self, x, time, jit=True): The Hessian matrix of the prediction function at each point in `x`. The shape of the output array is `x.shape + x.shape[1:]`. """ - Xnew = _validate_time_x( + Xnew = validate_time_x( x, time, n_features=self.n_input_features, cast_scalar=True ) X, time = Xnew[:, :-1], Xnew[:, -1] @@ -967,7 +967,7 @@ def hessian_log_determinant(self, x, time, jit=True): `signs.shape == log_determinants.shape == x.shape[0]`. """ - Xnew = _validate_time_x( + Xnew = validate_time_x( x, time, n_features=self.n_input_features, cast_scalar=True ) X, time = Xnew[:, :-1], Xnew[:, -1] diff --git a/mellon/compute_ls_time.py b/mellon/compute_ls_time.py index 4ac4d1e..3abaefb 100644 --- a/mellon/compute_ls_time.py +++ b/mellon/compute_ls_time.py @@ -4,7 +4,7 @@ from jax.numpy.linalg import norm from jaxopt import ScipyMinimize from .density_estimator import DensityEstimator -from .validation import _validate_time_x +from .validation import validate_time_x logger = logging.getLogger("mellon") @@ -55,7 +55,7 @@ def compute_ls_time( The unique time points in the training instances. Only returned if `return_data` is True. 
""" - x = _validate_time_x(x, times) + x = validate_time_x(x, times) times = x[:, -1] states = x[:, :-1] unique_times = unique(times) diff --git a/mellon/density_estimator.py b/mellon/density_estimator.py index 289a2bc..3d9503e 100644 --- a/mellon/density_estimator.py +++ b/mellon/density_estimator.py @@ -20,8 +20,8 @@ DEFAULT_JITTER, ) from .validation import ( - _validate_string, - _validate_array, + validate_string, + validate_array, ) @@ -223,7 +223,7 @@ def __init__( jit=jit, check_rank=check_rank, ) - self.d_method = _validate_string( + self.d_method = validate_string( d_method, "d_method", choices={"fractal", "embedding"} ) self.transform = None @@ -347,7 +347,7 @@ def prepare_inference(self, x): self._prepare_attribute("n_landmarks") self._prepare_attribute("rank") self._prepare_attribute("gp_type") - self._validate_parameter() + self.validate_parameter() self._prepare_attribute("nn_distances") self._prepare_attribute("d") self._prepare_attribute("mu") @@ -401,7 +401,7 @@ def process_inference(self, pre_transformation=None, build_predict=True): :rtype: array-like """ if pre_transformation is not None: - self.pre_transformation = _validate_array( + self.pre_transformation = validate_array( pre_transformation, "pre_transformation" ) self._set_log_density_x() @@ -493,7 +493,7 @@ def fit_predict(self, x=None, build_predict=False): if x is None: x = self.x else: - x = _validate_array(x, "x") + x = validate_array(x, "x") self.fit(x, build_predict=build_predict) return self.log_density_x diff --git a/mellon/derivatives.py b/mellon/derivatives.py index 439b237..f5468cc 100644 --- a/mellon/derivatives.py +++ b/mellon/derivatives.py @@ -1,7 +1,7 @@ import jax from jax.numpy import isscalar, atleast_2d -from .validation import _validate_1d, _validate_float +from .validation import validate_1d, validate_float def derivative(function, x, jit=True): @@ -35,10 +35,10 @@ def get_grad(x): return jax.jacrev(function)(x) if isscalar(x): - x = _validate_float(x, "x") + x = validate_float(x, "x") return get_grad(x) - x = _validate_1d(x) + x = validate_1d(x) if jit: get_grad = jax.jit(get_grad) diff --git a/mellon/dimensionality_estimator.py b/mellon/dimensionality_estimator.py index 35ef7a7..7ba596b 100644 --- a/mellon/dimensionality_estimator.py +++ b/mellon/dimensionality_estimator.py @@ -22,9 +22,9 @@ object_str, ) from .validation import ( - _validate_positive_int, - _validate_float, - _validate_array, + validate_positive_int, + validate_float, + validate_array, ) @@ -221,10 +221,10 @@ def __init__( jit=jit, check_rank=check_rank, ) - self.k = _validate_positive_int(k, "k") - self.mu_dim = _validate_float(mu_dim, "mu_dim") - self.mu_dens = _validate_float(mu_dens, "mu_dens", optional=True) - self.distances = _validate_array(distances, "distances", optional=True) + self.k = validate_positive_int(k, "k") + self.mu_dim = validate_float(mu_dim, "mu_dim") + self.mu_dens = validate_float(mu_dens, "mu_dens", optional=True) + self.distances = validate_array(distances, "distances", optional=True) self.transform = None self.loss_func = None self.opt_state = None @@ -419,7 +419,7 @@ def prepare_inference(self, x): self._prepare_attribute("n_landmarks") self._prepare_attribute("rank") self._prepare_attribute("gp_type") - self._validate_parameter() + self.validate_parameter() self._prepare_attribute("distances") self._prepare_attribute("nn_distances") self._prepare_attribute("d") @@ -596,7 +596,7 @@ def fit_predict(self, x=None, build_predict=False): if x is None: x = self.x else: - x = _validate_array(x, "x") 
+ x = validate_array(x, "x") self.fit(x, build_predict=build_predict) return self.local_dim_x diff --git a/mellon/function_estimator.py b/mellon/function_estimator.py index 151ce50..f886d0c 100644 --- a/mellon/function_estimator.py +++ b/mellon/function_estimator.py @@ -12,10 +12,10 @@ object_str, ) from .validation import ( - _validate_float_or_iterable_numerical, - _validate_float, - _validate_array, - _validate_bool, + validate_float_or_iterable_numerical, + validate_float, + validate_array, + validate_bool, ) @@ -145,9 +145,9 @@ def __init__( predictor_with_uncertainty=predictor_with_uncertainty, jit=jit, ) - self.y_is_mean = _validate_bool(y_is_mean, "y_is_mean") - self.mu = _validate_float(mu, "mu") - self.sigma = _validate_float_or_iterable_numerical( + self.y_is_mean = validate_bool(y_is_mean, "y_is_mean") + self.mu = validate_float(mu, "mu") + self.sigma = validate_float_or_iterable_numerical( sigma, "sigma", positive=True ) if ( @@ -248,7 +248,7 @@ def compute_conditional(self, x=None, y=None): if x is None: x = self.x else: - x = _validate_array(x, "x") + x = validate_array(x, "x") if self.x is not None and self.x is not x: message = ( "self.x has been set already, but is not equal to the argument x. " @@ -312,7 +312,7 @@ def fit(self, x=None, y=None): This method returns self for chaining. """ x = self.set_x(x) - y = _validate_array(y, "y") + y = validate_array(y, "y") n_samples = x.shape[0] # Check if the number of samples in x and y match @@ -382,8 +382,8 @@ def fit_predict(self, x=None, y=None, Xnew=None): `Xnew` don't match. """ x = self.set_x(x) - y = _validate_array(y, "y") - Xnew = _validate_array(Xnew, "Xnew", optional=True) + y = validate_array(y, "y") + Xnew = validate_array(Xnew, "Xnew", optional=True) # If Xnew is not provided, default to x if Xnew is None: @@ -443,7 +443,7 @@ def multi_fit_predict(self, x=None, Y=None, Xnew=None): # Set the x and validate inputs x = self.set_x(x) - Y = _validate_array(Y, "Y") + Y = validate_array(Y, "Y") n_samples = x.shape[0] diff --git a/mellon/parameter_validation.py b/mellon/parameter_validation.py index 958262f..4d54ddb 100644 --- a/mellon/parameter_validation.py +++ b/mellon/parameter_validation.py @@ -3,14 +3,14 @@ from .util import GaussianProcessType from .base_cov import Covariance from .validation import ( - _validate_positive_int, - _validate_float_or_int, + validate_positive_int, + validate_float_or_int, ) logger = logging.getLogger("mellon") -def _validate_landmark_params(n_landmarks, landmarks): +def validate_landmark_params(n_landmarks, landmarks): """ Validates that n_landmarks and landmarks are compatible. @@ -31,7 +31,7 @@ def _validate_landmark_params(n_landmarks, landmarks): raise ValueError(message) -def _validate_rank_params(gp_type, n_samples, rank, n_landmarks): +def validate_rank_params(gp_type, n_samples, rank, n_landmarks): """ Validates that rank, n_landmarks, and gp_type are compatible. @@ -93,7 +93,7 @@ def _validate_rank_params(gp_type, n_samples, rank, n_landmarks): raise ValueError(message) -def _validate_gp_type(gp_type, n_samples, n_landmarks): +def validate_gp_type(gp_type, n_samples, n_landmarks): """ Validates that gp_type, n_samples, and n_landmarks are compatible. 
@@ -146,7 +146,7 @@ def _validate_gp_type(gp_type, n_samples, n_landmarks): raise ValueError(message) -def _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks): +def validate_params(rank, gp_type, n_samples, n_landmarks, landmarks): """ Validates that rank, gp_type, n_samples, n_landmarks, and landmarks are compatible. @@ -168,8 +168,8 @@ def _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks): The given landmarks/inducing points. """ - n_landmarks = _validate_positive_int(n_landmarks, "n_landmarks") - rank = _validate_float_or_int(rank, "rank") + n_landmarks = validate_positive_int(n_landmarks, "n_landmarks") + rank = validate_float_or_int(rank, "rank") if not isinstance(gp_type, GaussianProcessType): message = ( @@ -180,19 +180,19 @@ def _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks): raise ValueError(message) # Validation logic for landmarks - _validate_landmark_params(n_landmarks, landmarks) + validate_landmark_params(n_landmarks, landmarks) if n_landmarks > n_samples: logger.warning( f"n_landmarks={n_landmarks:,} is larger than the number of cells {n_samples:,}." ) - _validate_gp_type(gp_type, n_samples, n_landmarks) + validate_gp_type(gp_type, n_samples, n_landmarks) # Validation logic for rank - _validate_rank_params(gp_type, n_samples, rank, n_landmarks) + validate_rank_params(gp_type, n_samples, rank, n_landmarks) -def _validate_cov_func_curry(cov_func_curry, cov_func, param_name): +def validate_cov_func_curry(cov_func_curry, cov_func, param_name): """ Validates covariance function curry type. @@ -229,7 +229,7 @@ def _validate_cov_func_curry(cov_func_curry, cov_func, param_name): return cov_func_curry -def _validate_cov_func(cov_func, param_name, optional=False): +def validate_cov_func(cov_func, param_name, optional=False): """ Validates an instance of a covariance function. @@ -263,7 +263,7 @@ def _validate_cov_func(cov_func, param_name, optional=False): return cov_func -def _validate_normalize_parameter(normalize, unique_times): +def validate_normalize_parameter(normalize, unique_times): """ Used in parameters.compute_nn_distances_within_time_points to validate input. """ diff --git a/mellon/parameters.py b/mellon/parameters.py index 272742e..4f5c3ca 100644 --- a/mellon/parameters.py +++ b/mellon/parameters.py @@ -35,15 +35,15 @@ DEFAULT_SIGMA, ) from .validation import ( - _validate_time_x, - _validate_positive_float, - _validate_float_or_int, - _validate_positive_int, - _validate_float_or_iterable_numerical, + validate_time_x, + validate_positive_float, + validate_float_or_int, + validate_positive_int, + validate_float_or_iterable_numerical, ) from .parameter_validation import ( - _validate_params, - _validate_normalize_parameter, + validate_params, + validate_normalize_parameter, ) @@ -188,9 +188,9 @@ def compute_gp_type(n_landmarks, rank, n_samples): One of the Gaussian Process types defined in the `GaussianProcessType` Enum. 
""" - rank = _validate_float_or_int(rank, "rank", optional=True) - n_landmarks = _validate_positive_int(n_landmarks, "n_landmarks") - n_samples = _validate_positive_int(n_samples, "n_samples") + rank = validate_float_or_int(rank, "rank", optional=True) + n_landmarks = validate_positive_int(n_landmarks, "n_landmarks") + n_samples = validate_positive_int(n_samples, "n_samples") if n_landmarks == 0 or n_landmarks >= n_samples: # Full model @@ -313,9 +313,9 @@ def compute_landmarks_rescale_time( if n_landmarks == 0: return None - ls = _validate_positive_float(ls, "ls") - ls_time = _validate_positive_float(ls_time, "ls_time") - x = _validate_time_x(x, times) + ls = validate_positive_float(ls, "ls") + ls_time = validate_positive_float(ls_time, "ls_time") + x = validate_time_x(x, times) time_factor = ls / ls_time x = x.at[:, -1].set(x[:, -1] * time_factor) landmarks = compute_landmarks(x, n_landmarks=n_landmarks) @@ -424,16 +424,16 @@ def compute_nn_distances_within_time_points(x, times=None, d=None, normalize=Fal preserving the order of instances in `x`. """ - x = _validate_time_x(x, times) + x = validate_time_x(x, times) unique_times = unique(x[:, -1]) nn_distances = empty(x.shape[0]) n_cells = x.shape[0] av_cells_per_tp = n_cells / len(unique_times) - _validate_normalize_parameter(normalize, unique_times) + validate_normalize_parameter(normalize, unique_times) if normalize is not False and normalize is not None: - d = _validate_float_or_iterable_numerical(d, "d", optional=False, positive=True) + d = validate_float_or_iterable_numerical(d, "d", optional=False, positive=True) if ndim(d) > 0 and len(d) != x.shape[0]: ld = len(d) raise ValueError( @@ -650,7 +650,7 @@ def compute_Lp( raise ValueError(message) -def _validate_compute_L_input(x, cov_func, gp_type, landmarks, Lp, rank, sigma, jitter): +def validate_compute_L_input(x, cov_func, gp_type, landmarks, Lp, rank, sigma, jitter): """ Validate input for the fuction compute_L. @@ -672,8 +672,8 @@ def _validate_compute_L_input(x, cov_func, gp_type, landmarks, Lp, rank, sigma, ValueError If any of the inputs are inconsistent or violate constraints. """ - jitter = _validate_positive_float(jitter, "jitter") - rank = _validate_float_or_int(rank, "rank", optional=True) + jitter = validate_positive_float(jitter, "jitter") + rank = validate_float_or_int(rank, "rank", optional=True) n_samples = x.shape[0] if landmarks is None: @@ -685,7 +685,7 @@ def _validate_compute_L_input(x, cov_func, gp_type, landmarks, Lp, rank, sigma, rank = compute_rank(gp_type) if gp_type is None: gp_type = compute_gp_type(n_landmarks, rank, n_samples) - _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) + validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) if ( gp_type == GaussianProcessType.FULL @@ -773,7 +773,7 @@ def compute_L( ValueError If the Gaussian Process type is unknown or if the shape of Lp is incorrect. 
""" - x, landmarks, n_landmarks, n_samples, gp_type, rank = _validate_compute_L_input( + x, landmarks, n_landmarks, n_samples, gp_type, rank = validate_compute_L_input( x, cov_func, gp_type, landmarks, Lp, rank, sigma, jitter ) diff --git a/mellon/time_sensitive_density_estimator.py b/mellon/time_sensitive_density_estimator.py index c8bdccf..5ac43b7 100644 --- a/mellon/time_sensitive_density_estimator.py +++ b/mellon/time_sensitive_density_estimator.py @@ -27,11 +27,11 @@ object_str, ) from .validation import ( - _validate_time_x, - _validate_positive_float, - _validate_string, - _validate_array, - _validate_nn_distances, + validate_time_x, + validate_positive_float, + validate_string, + validate_array, + validate_nn_distances, ) @@ -287,11 +287,11 @@ def __init__( if not isinstance(density_estimator_kwargs, dict): raise ValueError("density_estimator_kwargs needs to be a dictionary.") self.density_estimator_kwargs = density_estimator_kwargs - self.d_method = _validate_string( + self.d_method = validate_string( d_method, "d_method", choices={"fractal", "embedding"} ) - self.ls_time = _validate_positive_float(ls_time, "ls_time", optional=True) - self.ls_time_factor = _validate_positive_float(ls_time_factor, "ls_time_factor") + self.ls_time = validate_positive_float(ls_time, "ls_time", optional=True) + self.ls_time_factor = validate_positive_float(ls_time_factor, "ls_time_factor") self._save_intermediate_ls_times = _save_intermediate_ls_times self.normalize_per_time_point = normalize_per_time_point self.transform = None @@ -398,7 +398,7 @@ def _compute_nn_distances(self): d=d, normalize=normalize_per_time_point, ) - nn_distances = _validate_nn_distances(nn_distances) + nn_distances = validate_nn_distances(nn_distances) return nn_distances def _compute_ls(self): @@ -546,7 +546,7 @@ def prepare_inference(self, x, times=None): message = "Required argument x is missing and self.x has not been set." raise ValueError(message) else: - x = _validate_time_x(x, times) + x = validate_time_x(x, times) if self.x is not None and self.x is not x: message = ( "self.x has been set already, but is not equal to the argument x." @@ -557,7 +557,7 @@ def prepare_inference(self, x, times=None): self._prepare_attribute("n_landmarks") self._prepare_attribute("rank") self._prepare_attribute("gp_type") - self._validate_parameter() + self.validate_parameter() self._prepare_attribute("d") self._prepare_attribute("nn_distances") self._prepare_attribute("mu") @@ -612,7 +612,7 @@ def process_inference(self, pre_transformation=None, build_predict=True): :rtype: array-like """ if pre_transformation is not None: - self.pre_transformation = _validate_array( + self.pre_transformation = validate_array( pre_transformation, "pre_transformation" ) self._set_log_density_x() @@ -686,7 +686,7 @@ def fit_predict(self, x=None, times=None, build_predict=False): :return: log_density_x - The log density at each training point in x. """ if x is not None: - x = _validate_time_x(x, times) + x = validate_time_x(x, times) if self.x is not None and x is not None and self.x is not x: message = "self.x has been set already, but is not equal to the argument x." 
error = ValueError(message) diff --git a/mellon/util.py b/mellon/util.py index 7a41827..5a4ebff 100644 --- a/mellon/util.py +++ b/mellon/util.py @@ -40,7 +40,7 @@ from jax import vmap, jit from sklearn.neighbors import BallTree, KDTree -from .validation import _validate_array +from .validation import validate_array logger = logging.getLogger("mellon") @@ -250,7 +250,7 @@ def wrapper(self, *args, **kwargs): raise ValueError( "Cannot specify both 'time' and 'multi_time' arguments" ) - multi_time = _validate_array(multi_time, "multi_time") + multi_time = validate_array(multi_time, "multi_time") def at_time(t): return func(self, *args, **kwargs, time=t) diff --git a/mellon/validation.py b/mellon/validation.py index 253cd47..a0929a5 100644 --- a/mellon/validation.py +++ b/mellon/validation.py @@ -10,7 +10,7 @@ logger = logging.getLogger(__name__) -def _validate_time_x(x, times=None, n_features=None, cast_scalar=False): +def validate_time_x(x, times=None, n_features=None, cast_scalar=False): """ Validates and concatenates 'x' and 'times' if 'times' is provided. @@ -51,14 +51,14 @@ def _validate_time_x(x, times=None, n_features=None, cast_scalar=False): of features. """ - x = _validate_array(x, "x", ndim=2) + x = validate_array(x, "x", ndim=2) if ( cast_scalar and times is not None and (isscalar(times) or all(s == 1 for s in times.shape)) ): times = full(x.shape[0], times) - times = _validate_array(times, "times", optional=True, ndim=(1, 2)) + times = validate_array(times, "times", optional=True, ndim=(1, 2)) if times is not None: # Validate 'times' shape @@ -91,7 +91,7 @@ def _validate_time_x(x, times=None, n_features=None, cast_scalar=False): return x -def _validate_float_or_int(value, param_name, optional=False): +def validate_float_or_int(value, param_name, optional=False): """ Validates whether a given value is a float or an integer, and not nan. @@ -132,7 +132,7 @@ def _validate_float_or_int(value, param_name, optional=False): return value -def _validate_positive_float(value, param_name, optional=False): +def validate_positive_float(value, param_name, optional=False): """ Validates whether a given value is a positive float, and non-NaN. @@ -175,7 +175,7 @@ def _validate_positive_float(value, param_name, optional=False): return value -def _validate_float(value, param_name, optional=False): +def validate_float(value, param_name, optional=False): """ Validates if the input is a float or can be converted to a float. @@ -223,7 +223,7 @@ def _validate_float(value, param_name, optional=False): return value -def _validate_positive_int(value, param_name, optional=False): +def validate_positive_int(value, param_name, optional=False): """ Validates whether a given value is a positive integer. @@ -254,7 +254,7 @@ def _validate_positive_int(value, param_name, optional=False): return value -def _validate_array(iterable, name, optional=False, ndim=None): +def validate_array(iterable, name, optional=False, ndim=None): """ Validates and converts an iterable to a numpy array of type float. Allows Jax's JVPTracer objects and avoids explicit conversion in these cases. @@ -316,7 +316,7 @@ def _validate_array(iterable, name, optional=False, ndim=None): return array -def _validate_bool(value, name, optional=False): +def validate_bool(value, name, optional=False): """ Validates whether a given value is a boolean. 
@@ -355,7 +355,7 @@ def _validate_bool(value, name, optional=False): return value -def _validate_string(value, name, choices=None): +def validate_string(value, name, choices=None): """ Validates whether a given value is a string and optionally whether it is in a set of choices. @@ -390,7 +390,7 @@ def _validate_string(value, name, choices=None): return value -def _validate_float_or_iterable_numerical(value, name, optional=False, positive=False): +def validate_float_or_iterable_numerical(value, name, optional=False, positive=False): """ Validates whether a given value is a float, integer, or iterable of numerical values, with an option to check for non-negativity. @@ -439,7 +439,7 @@ def _validate_float_or_iterable_numerical(value, name, optional=False, positive= ) -def _validate_1d(x): +def validate_1d(x): """ Validates that `x` can be cast to a JAX array with exactly 1 dimension and float data type. @@ -470,7 +470,7 @@ def _validate_1d(x): return x -def _validate_nn_distances(nn_distances, optional=False): +def validate_nn_distances(nn_distances, optional=False): """ Validates and corrects nearest neighbor distances. Ensures all distances are positive and handles invalid values. diff --git a/tests/test_validation.py b/tests/test_validation.py index de7c0d4..b1f9733 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,22 +1,22 @@ import pytest import jax.numpy as jnp from mellon.validation import ( - _validate_time_x, - _validate_float_or_int, - _validate_positive_float, - _validate_float, - _validate_array, - _validate_bool, - _validate_string, - _validate_float_or_iterable_numerical, - _validate_positive_int, - _validate_1d, - _validate_nn_distances, + validate_time_x, + validate_float_or_int, + validate_positive_float, + validate_float, + validate_array, + validate_bool, + validate_string, + validate_float_or_iterable_numerical, + validate_positive_int, + validate_1d, + validate_nn_distances, ) from mellon.parameter_validation import ( - _validate_params, - _validate_cov_func_curry, - _validate_cov_func, + validate_params, + validate_cov_func_curry, + validate_cov_func, ) from mellon.cov import Covariance from mellon.util import GaussianProcessType @@ -54,228 +54,228 @@ (100, GaussianProcessType.SPARSE_NYSTROEM, 100, 50, None, ValueError), ], ) -def test_validate_params( +def testvalidate_params( rank, gp_type, n_samples, n_landmarks, landmarks, exception_expected ): if exception_expected: with pytest.raises(exception_expected): - _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) + validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) else: - _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) + validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) -def test_validate_float_or_int(): +def testvalidate_float_or_int(): # Test with integer input - assert _validate_float_or_int(10, "param") == 10 + assert validate_float_or_int(10, "param") == 10 # Test with float input - assert _validate_float_or_int(10.5, "param") == 10.5 + assert validate_float_or_int(10.5, "param") == 10.5 # Test with string input with pytest.raises(ValueError): - _validate_float_or_int("string", "param") + validate_float_or_int("string", "param") # Test with None input and optional=True - assert _validate_float_or_int(None, "param", optional=True) is None + assert validate_float_or_int(None, "param", optional=True) is None # Test with None input and optional=False with pytest.raises(ValueError): - _validate_float_or_int(None, "param", 
optional=False) + validate_float_or_int(None, "param", optional=False) # Test with nan value with pytest.raises(ValueError): - _validate_float_or_int(jnp.nan, "param") + validate_float_or_int(jnp.nan, "param") -def test_validate_positive_float(): +def testvalidate_positive_float(): # Test with positive float input - assert _validate_positive_float(10.5, "param") == 10.5 + assert validate_positive_float(10.5, "param") == 10.5 # Test with negative float input with pytest.raises(ValueError): - _validate_positive_float(-10.5, "param") + validate_positive_float(-10.5, "param") # Test with positive integer input - assert _validate_positive_float(10, "param") == 10.0 + assert validate_positive_float(10, "param") == 10.0 # Test with negative integer input with pytest.raises(ValueError): - _validate_positive_float(-10, "param") + validate_positive_float(-10, "param") # Test with string input with pytest.raises(ValueError): - _validate_positive_float("string", "param") + validate_positive_float("string", "param") # Test with None input and optional=True - assert _validate_positive_float(None, "param", optional=True) is None + assert validate_positive_float(None, "param", optional=True) is None # Test with None input and optional=False with pytest.raises(ValueError): - _validate_positive_float(None, "param", optional=False) + validate_positive_float(None, "param", optional=False) # Test with nan value with pytest.raises(ValueError): - _validate_positive_float(jnp.nan, "param") + validate_positive_float(jnp.nan, "param") -def test_validate_positive_int(): +def testvalidate_positive_int(): # Test with positive integer input - assert _validate_positive_int(10, "param") == 10 + assert validate_positive_int(10, "param") == 10 # Test with negative integer input with pytest.raises(ValueError): - _validate_positive_int(-10, "param") + validate_positive_int(-10, "param") # Test with float input with pytest.raises(ValueError): - _validate_positive_int(10.5, "param") + validate_positive_int(10.5, "param") # Test with None input and optional=True - assert _validate_positive_int(None, "param", optional=True) is None + assert validate_positive_int(None, "param", optional=True) is None # Test with None input and optional=False with pytest.raises(ValueError): - _validate_positive_int(None, "param", optional=False) + validate_positive_int(None, "param", optional=False) # Test with nan value with pytest.raises(ValueError): - _validate_positive_int(jnp.nan, "param") + validate_positive_int(jnp.nan, "param") -def test_validate_array(): +def testvalidate_array(): # Test with array-like input array = jnp.array([1, 2, 3]) - validated_array = _validate_array(array, "param", ndim=1) + validated_array = validate_array(array, "param", ndim=1) assert jnp.array_equal(validated_array, array) # Test with non-array input with pytest.raises(TypeError): - _validate_array(10, "param") + validate_array(10, "param") # Test with None input and optional=True - assert _validate_array(None, "param", optional=True) is None + assert validate_array(None, "param", optional=True) is None # Test with None input and optional=False with pytest.raises(TypeError): - _validate_array(None, "param", optional=False) + validate_array(None, "param", optional=False) # Test with incorrect number of dimensions with pytest.raises(ValueError): - _validate_array(array, "param", ndim=2) + validate_array(array, "param", ndim=2) -def test_validate_bool(): +def testvalidate_bool(): # Test with bool input - assert _validate_bool(True, "param") is True + assert 
validate_bool(True, "param") is True # Test with non-bool input with pytest.raises(TypeError): - _validate_bool(10, "param") + validate_bool(10, "param") -def test_validate_string(): +def testvalidate_string(): # Test with string input - assert _validate_string("test", "param") == "test" + assert validate_string("test", "param") == "test" # Test with non-string input with pytest.raises(TypeError): - _validate_string(10, "param") + validate_string(10, "param") # Test with invalid choice with pytest.raises(ValueError): - _validate_string("invalid", "param", choices=["valid", "test"]) + validate_string("invalid", "param", choices=["valid", "test"]) -def test_validate_float_or_iterable_numerical(): +def testvalidate_float_or_iterable_numerical(): # Test with positive numbers - assert _validate_float_or_iterable_numerical(5, "value") == 5.0 + assert validate_float_or_iterable_numerical(5, "value") == 5.0 assert jnp.allclose( - _validate_float_or_iterable_numerical([5, 6], "value"), jnp.asarray([5.0, 6.0]) + validate_float_or_iterable_numerical([5, 6], "value"), jnp.asarray([5.0, 6.0]) ) # Test with negative numbers, without positive constraint - assert _validate_float_or_iterable_numerical(-5, "value") == -5.0 + assert validate_float_or_iterable_numerical(-5, "value") == -5.0 assert jnp.allclose( - _validate_float_or_iterable_numerical([-5, -6], "value"), + validate_float_or_iterable_numerical([-5, -6], "value"), jnp.asarray([-5.0, -6.0]), ) # Test with zero - assert _validate_float_or_iterable_numerical(0, "value") == 0.0 + assert validate_float_or_iterable_numerical(0, "value") == 0.0 # Test with positive=True - assert _validate_float_or_iterable_numerical(5, "value", positive=True) == 5.0 + assert validate_float_or_iterable_numerical(5, "value", positive=True) == 5.0 # Test with negative numbers and positive=True with pytest.raises(ValueError): - _validate_float_or_iterable_numerical(-5, "value", positive=True) + validate_float_or_iterable_numerical(-5, "value", positive=True) with pytest.raises(ValueError): - _validate_float_or_iterable_numerical([-5, 6], "value", positive=True) + validate_float_or_iterable_numerical([-5, 6], "value", positive=True) # Test with None and optional=True - assert _validate_float_or_iterable_numerical(None, "value", optional=True) is None + assert validate_float_or_iterable_numerical(None, "value", optional=True) is None # Test with None and optional=False with pytest.raises(TypeError): - _validate_float_or_iterable_numerical(None, "value", optional=False) + validate_float_or_iterable_numerical(None, "value", optional=False) # Test with non-numeric types with pytest.raises(TypeError): - _validate_float_or_iterable_numerical("string", "value") + validate_float_or_iterable_numerical("string", "value") with pytest.raises(ValueError): - _validate_float_or_iterable_numerical(["string"], "value") + validate_float_or_iterable_numerical(["string"], "value") # Test with mixed numeric and non-numeric iterable with pytest.raises(ValueError): - _validate_float_or_iterable_numerical([5, "string"], "value") + validate_float_or_iterable_numerical([5, "string"], "value") -def test_validate_time_x(): +def testvalidate_time_x(): # Test with only 'x' and no 'times' or 'n_features' x = jnp.array([[1, 2], [3, 4], [5, 6]]) - result = _validate_time_x(x) + result = validate_time_x(x) assert jnp.array_equal(result, x) # Test with 'x' and 'times' times = jnp.array([1, 2, 3]) expected_result = jnp.array([[1, 2, 1], [3, 4, 2], [5, 6, 3]]) - result = _validate_time_x(x, times) + result = 
validate_time_x(x, times) assert jnp.array_equal(result, expected_result) # Test with 'x' and 'times' with shape (n_samples, 1) times = jnp.array([[1], [2], [3]]) - result = _validate_time_x(x, times) + result = validate_time_x(x, times) assert jnp.array_equal(result, expected_result) # Test with 'x' and 'times' but mismatched number of samples times = jnp.array([1, 2]) with pytest.raises(ValueError): - _validate_time_x(x, times) + validate_time_x(x, times) # Test with 'x' and 'times' but 'times' is not 1D or 2D with 1 column times = jnp.array([[1, 1], [2, 2], [3, 3]]) with pytest.raises(ValueError): - _validate_time_x(x, times) + validate_time_x(x, times) # Test with 'x', 'times', and 'n_features' correct times = jnp.array([1, 2, 3]) - result = _validate_time_x(x, times, n_features=3) + result = validate_time_x(x, times, n_features=3) assert jnp.array_equal(result, expected_result) # Test with 'x', 'times', and 'n_features' incorrect with pytest.raises(ValueError): - _validate_time_x(x, times, n_features=2) + validate_time_x(x, times, n_features=2) # Test with 'x', no 'times', and 'n_features' incorrect with pytest.raises(ValueError): - _validate_time_x(x, n_features=3) + validate_time_x(x, n_features=3) # Test with scalar 'times' and 'cast_scalar' set to True times = 1 expected_result = jnp.array([[1, 2, 1], [3, 4, 1], [5, 6, 1]]) - result = _validate_time_x(x, times, cast_scalar=True) + result = validate_time_x(x, times, cast_scalar=True) assert jnp.array_equal(result, expected_result) @@ -287,125 +287,125 @@ def k(self): pass -def test_validate_cov_func_curry(): +def testvalidate_cov_func_curry(): # Test with both parameters as None with pytest.raises(ValueError): - _validate_cov_func_curry(None, None, "cov_func_curry") + validate_cov_func_curry(None, None, "cov_func_curry") # Test with valid covariance function curry cov_func_curry = CustomCovariance - result = _validate_cov_func_curry(cov_func_curry, None, "cov_func_curry") + result = validate_cov_func_curry(cov_func_curry, None, "cov_func_curry") assert result == cov_func_curry # Test with invalid covariance function curry cov_func_curry = "Invalid" with pytest.raises(ValueError): - _validate_cov_func_curry(cov_func_curry, None, "cov_func_curry") + validate_cov_func_curry(cov_func_curry, None, "cov_func_curry") -def test_validate_cov_func(): +def testvalidate_cov_func(): # Test with valid covariance function cov_func = CustomCovariance() - result = _validate_cov_func(cov_func, "cov_func") + result = validate_cov_func(cov_func, "cov_func") assert result == cov_func # Test with invalid covariance function cov_func = "Invalid" with pytest.raises(ValueError): - _validate_cov_func(cov_func, "cov_func") + validate_cov_func(cov_func, "cov_func") # Test with None as optional - result = _validate_cov_func(None, "cov_func", True) + result = validate_cov_func(None, "cov_func", True) assert result is None # Test with None as not optional with pytest.raises(ValueError): - _validate_cov_func(None, "cov_func", False) + validate_cov_func(None, "cov_func", False) -def test_validate_1d(): +def testvalidate_1d(): # Test with valid 1D array arr = jnp.array([1.2, 2.3, 3.4]) - result = _validate_1d(arr) + result = validate_1d(arr) assert jnp.allclose(result, arr), "Arrays not equal" # Test with a scalar scalar = 2.3 - result = _validate_1d(scalar) + result = validate_1d(scalar) assert jnp.allclose(result, jnp.array([scalar])), "Arrays not equal" # Test with 2D array arr = jnp.array([[1.2, 2.3, 3.4], [4.5, 5.6, 6.7]]) with pytest.raises(ValueError): - 
_validate_1d(arr) + validate_1d(arr) # Test with string, should raise an error due to dtype mismatch string = "invalid" with pytest.raises(ValueError): - _validate_1d(string) + validate_1d(string) -def test_validate_float(): +def testvalidate_float(): # Test with valid float - result = _validate_float(1.5, "param1") + result = validate_float(1.5, "param1") assert result == 1.5 # Test with valid int - result = _validate_float(2, "param1") + result = validate_float(2, "param1") assert result == 2.0 # Test with 1x1 array - result = _validate_float(jnp.array([2.0]), "param1") + result = validate_float(jnp.array([2.0]), "param1") assert result == 2.0 # Test with None and optional - result = _validate_float(None, "param1", optional=True) + result = validate_float(None, "param1", optional=True) assert result is None # Test with None and not optional with pytest.raises(ValueError): - _validate_float(None, "param1") + validate_float(None, "param1") # Test with invalid type with pytest.raises(ValueError): - _validate_float("not a float", "param1") + validate_float("not a float", "param1") # Test with invalid type (non-numeric) with pytest.raises(ValueError): - _validate_float([1, 2, 3], "param1") + validate_float([1, 2, 3], "param1") # Test with nan value with pytest.raises(ValueError): - _validate_float(jnp.nan, "param1") + validate_float(jnp.nan, "param1") -def test_validate_nn_distances(): +def testvalidate_nn_distances(): # Test with all valid distances nn_distances = jnp.array([0.1, 0.5, 1.2, 0.3]) - result = _validate_nn_distances(nn_distances) + result = validate_nn_distances(nn_distances) assert jnp.all(result == nn_distances), "Valid distances should not be changed." # Test with NaN values nn_distances = jnp.array([0.1, jnp.nan, 1.2, 0.3]) - result = _validate_nn_distances(nn_distances) + result = validate_nn_distances(nn_distances) assert jnp.any(result != nn_distances), "NaN values should be replaced." assert not jnp.isnan(result).any(), "Result should not contain NaN values." # Test with infinite values nn_distances = jnp.array([0.1, jnp.inf, 1.2, 0.3]) - result = _validate_nn_distances(nn_distances) + result = validate_nn_distances(nn_distances) assert jnp.any(result != nn_distances), "Infinite values should be replaced." assert not jnp.isinf(result).any(), "Result should not contain infinite values." # Test with negative values nn_distances = jnp.array([0.1, -0.5, 1.2, 0.3]) - result = _validate_nn_distances(nn_distances) + result = validate_nn_distances(nn_distances) assert jnp.any(result != nn_distances), "Negative values should be replaced." assert jnp.all(result > 0), "Result should not contain negative values." # Test with a mix of invalid values nn_distances = jnp.array([0.1, jnp.nan, jnp.inf, -0.5, 1.2]) - result = _validate_nn_distances(nn_distances) + result = validate_nn_distances(nn_distances) assert jnp.any(result != nn_distances), "Invalid values should be replaced." assert not jnp.isnan(result).any(), "Result should not contain NaN values." assert not jnp.isinf(result).any(), "Result should not contain infinite values." 
@@ -414,13 +414,13 @@ def test_validate_nn_distances():
     # Test with all invalid values
     nn_distances = jnp.array([jnp.nan, jnp.inf, -0.5])
     with pytest.raises(ValueError):
-        _validate_nn_distances(nn_distances)
+        validate_nn_distances(nn_distances)
 
     # Test with optional=True and nn_distances=None
     assert (
-        _validate_nn_distances(None, optional=True) is None
+        validate_nn_distances(None, optional=True) is None
     ), "Should return None if optional is True and nn_distances is None."
 
     # Test with optional=False and nn_distances=None
     with pytest.raises(ValueError):
-        _validate_nn_distances(None, optional=False)
+        validate_nn_distances(None, optional=False)
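
Downstream usage after this rename, as a minimal sketch: it assumes a mellon build that already contains this change, where the validation helpers in `mellon.validation` are exposed without the leading underscore (the hunks above rename the `_validate_*` functions and, as far as shown, add no compatibility aliases), so imports must switch to the new names:

    import jax.numpy as jnp

    from mellon.validation import validate_array, validate_positive_float

    # validate_array converts the input to a float array and, when ndim is given,
    # checks the number of dimensions (ValueError on mismatch, TypeError on non-arrays).
    x = validate_array(jnp.array([[1.0, 2.0], [3.0, 4.0]]), "x", ndim=2)

    # validate_positive_float returns the value as a float; it raises ValueError for
    # non-positive, NaN, or None (when optional=False) input.
    ls = validate_positive_float(1.5, "ls")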