Rename all _validate methods to validate
katosh committed Jun 4, 2024
1 parent b838318 commit 659b900
Showing 14 changed files with 265 additions and 264 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
 - Detailed logging about invalid `nn_distances`.
 - Validation of `nn_distances` passed at initialization
 - Validate that passed scalars are not `nan`.
+- Rename all `_validate` methods to `validate`
 
 # v1.4.2
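The rename drops the leading underscore, making the validation helpers part of mellon's public surface. A minimal sketch of calling them directly, assuming they are importable from `mellon.validation` and return the validated value, as the call sites in the diffs below suggest:

    from mellon.validation import (
        validate_array,
        validate_bool,
        validate_positive_float,
        validate_string,
    )

    # Each helper returns the validated value, or raises a descriptive
    # error that names the offending parameter.
    nn = validate_array([0.1, 0.2, 0.3], "nn_distances")
    jit = validate_bool(True, "jit")
    ls = validate_positive_float(2.5, "ls")
    opt = validate_string("adam", "optimizer", choices={"adam", "advi", "L-BFGS-B"})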
74 changes: 37 additions & 37 deletions mellon/base_model.py
@@ -27,20 +27,20 @@
     GaussianProcessType,
 )
 from .validation import (
-    _validate_positive_int,
-    _validate_positive_float,
-    _validate_float_or_int,
-    _validate_float,
-    _validate_string,
-    _validate_bool,
-    _validate_array,
-    _validate_float_or_iterable_numerical,
-    _validate_nn_distances,
+    validate_positive_int,
+    validate_positive_float,
+    validate_float_or_int,
+    validate_float,
+    validate_string,
+    validate_bool,
+    validate_array,
+    validate_float_or_iterable_numerical,
+    validate_nn_distances,
 )
 from .parameter_validation import (
-    _validate_params,
-    _validate_cov_func_curry,
-    _validate_cov_func,
+    validate_params,
+    validate_cov_func_curry,
+    validate_cov_func,
 )
@@ -80,42 +80,42 @@ def __init__(
         jit=DEFAULT_JIT,
         check_rank=None,
     ):
-        self.cov_func_curry = _validate_cov_func_curry(
+        self.cov_func_curry = validate_cov_func_curry(
             cov_func_curry, cov_func, "cov_func_curry"
         )
-        self.n_landmarks = _validate_positive_int(
+        self.n_landmarks = validate_positive_int(
             n_landmarks, "n_landmarks", optional=True
         )
-        self.rank = _validate_float_or_int(rank, "rank", optional=True)
-        self.jitter = _validate_positive_float(jitter, "jitter")
-        self.landmarks = _validate_array(landmarks, "landmarks", optional=True)
+        self.rank = validate_float_or_int(rank, "rank", optional=True)
+        self.jitter = validate_positive_float(jitter, "jitter")
+        self.landmarks = validate_array(landmarks, "landmarks", optional=True)
         self.gp_type = GaussianProcessType.from_string(gp_type, optional=True)
-        self.nn_distances = _validate_array(nn_distances, "nn_distances", optional=True)
-        self.nn_distances = _validate_nn_distances(self.nn_distances, optional=True)
-        self.mu = _validate_float(mu, "mu", optional=True)
-        self.ls = _validate_positive_float(ls, "ls", optional=True)
-        self.ls_factor = _validate_positive_float(ls_factor, "ls_factor")
-        self.cov_func = _validate_cov_func(cov_func, "cov_func", optional=True)
-        self.Lp = _validate_array(Lp, "Lp", optional=True)
-        self.L = _validate_array(L, "L", optional=True)
-        self.d = _validate_float_or_iterable_numerical(
+        self.nn_distances = validate_array(nn_distances, "nn_distances", optional=True)
+        self.nn_distances = validate_nn_distances(self.nn_distances, optional=True)
+        self.mu = validate_float(mu, "mu", optional=True)
+        self.ls = validate_positive_float(ls, "ls", optional=True)
+        self.ls_factor = validate_positive_float(ls_factor, "ls_factor")
+        self.cov_func = validate_cov_func(cov_func, "cov_func", optional=True)
+        self.Lp = validate_array(Lp, "Lp", optional=True)
+        self.L = validate_array(L, "L", optional=True)
+        self.d = validate_float_or_iterable_numerical(
             d, "d", optional=True, positive=True
         )
-        self.initial_value = _validate_array(
+        self.initial_value = validate_array(
             initial_value, "initial_value", optional=True
         )
-        self.optimizer = _validate_string(
+        self.optimizer = validate_string(
             optimizer, "optimizer", choices={"adam", "advi", "L-BFGS-B"}
         )
-        self.n_iter = _validate_positive_int(n_iter, "n_iter")
-        self.init_learn_rate = _validate_positive_float(
+        self.n_iter = validate_positive_int(n_iter, "n_iter")
+        self.init_learn_rate = validate_positive_float(
             init_learn_rate, "init_learn_rate"
         )
-        self.predictor_with_uncertainty = _validate_bool(
+        self.predictor_with_uncertainty = validate_bool(
             predictor_with_uncertainty, "predictor_with_uncertainty"
         )
-        self.jit = _validate_bool(jit, "jit")
-        self.check_rank = _validate_bool(check_rank, "check_rank", optional=True)
+        self.jit = validate_bool(jit, "jit")
+        self.check_rank = validate_bool(check_rank, "check_rank", optional=True)
         self.x = None
         self.pre_transformation = None
@@ -202,7 +202,7 @@ def set_x(self, x):
             raise error
         if x is None:
             x = self.x
-        self.x = _validate_array(x, "x")
+        self.x = validate_array(x, "x")
         return self.x
 
     def _compute_n_landmarks(self):
@@ -242,7 +242,7 @@ def _compute_nn_distances(self):
         x = self.x
         logger.info("Computing nearest neighbor distances.")
         nn_distances = compute_nn_distances(x)
-        nn_distances = _validate_nn_distances(nn_distances)
+        nn_distances = validate_nn_distances(nn_distances)
         return nn_distances
 
     def _compute_ls(self):
@@ -341,7 +341,7 @@ def _compute_L(self):
 
         return L
 
-    def _validate_parameter(self):
+    def validate_parameter(self):
         """
         Make sure there are no contradictions in the parameter settings.
         """
@@ -350,7 +350,7 @@ def _validate_parameter(self):
         n_samples = self.x.shape[0]
         n_landmarks = self.n_landmarks
         landmarks = self.landmarks
-        _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks)
+        validate_params(rank, gp_type, n_samples, n_landmarks, landmarks)
 
     def _run_inference(self):
        function = self.loss_func
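Per the changelog, `nn_distances` are validated both at initialization and after computation, and the renamed `validate_nn_distances` is now callable directly. A hedged sketch of standalone use, assuming it returns the cleaned distances; the `optional=True` call mirrors the constructor above, where `nn_distances` may still be `None`:

    from jax.numpy import array
    from mellon.validation import validate_nn_distances

    # Returns the validated nearest-neighbor distances, with detailed logging
    # on invalid entries (behavior assumed from the changelog notes above).
    nn = validate_nn_distances(array([0.10, 0.23, 0.05]))
    maybe_nn = validate_nn_distances(None, optional=True)  # assumed to tolerate None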
46 changes: 23 additions & 23 deletions mellon/base_predictor.py
@@ -27,10 +27,10 @@
     hessian_log_determinant,
 )
 from .validation import (
-    _validate_time_x,
-    _validate_float,
-    _validate_array,
-    _validate_bool,
+    validate_time_x,
+    validate_float,
+    validate_array,
+    validate_bool,
 )
 
 
@@ -164,9 +164,9 @@ def mean(self, x, normalize=False):
             If the number of features in 'x' does not match the
             number of features the predictor was trained on.
         """
-        x = _validate_array(x, "x")
+        x = validate_array(x, "x")
         x = ensure_2d(x)
-        normalize = _validate_bool(normalize, "normalize")
+        normalize = validate_bool(normalize, "normalize")
 
         if x.shape[1] != self.n_input_features:
             raise ValueError(
@@ -214,7 +214,7 @@ def covariance(self, x, diag=True):
         cov : array-like, shape (n_samples, n_samples)
             If diag=False, returns the full covariance matrix between samples.
         """
-        x = _validate_array(x, "x")
+        x = validate_array(x, "x")
         x = ensure_2d(x)
         if x.shape[1] != self.n_input_features:
             raise ValueError(
@@ -247,7 +247,7 @@ def mean_covariance(self, x, diag=True):
         cov : array-like, shape (n_samples, n_samples)
             If diag=False, returns the full covariance matrix between samples.
         """
-        x = _validate_array(x, "x")
+        x = validate_array(x, "x")
         x = ensure_2d(x)
         if x.shape[1] != self.n_input_features:
             raise ValueError(
@@ -278,7 +278,7 @@ def uncertainty(self, x, diag=True):
         cov : array-like, shape (n_samples, n_samples) if diag=False
             The full covariance matrix between the samples in the new data points.
         """
-        x = _validate_array(x, "x")
+        x = validate_array(x, "x")
         x = ensure_2d(x)
         if x.shape[1] != self.n_input_features:
             raise ValueError(
@@ -309,7 +309,7 @@ def gradient(self, x, jit=True):
             gradients.shape == x.shape
         :rtype: array-like
         """
-        x = _validate_array(x, "x")
+        x = validate_array(x, "x")
         x = ensure_2d(x)
 
         return gradient(self._mean, x, jit=jit)
@@ -326,7 +326,7 @@ def hessian(self, x, jit=True):
             hessians.shape == X.shape + X.shape[1:]
         :rtype: array-like
         """
-        x = _validate_array(x, "x")
+        x = validate_array(x, "x")
         x = ensure_2d(x)
         return hessian(self.__call__, x, jit=jit)
 
@@ -344,7 +344,7 @@ def hessian_log_determinant(self, x, jit=True):
             signs.shape == log_determinants.shape == x.shape[0]
         :rtype: array-like, array-like
         """
-        x = _validate_array(x, "x")
+        x = validate_array(x, "x")
         x = ensure_2d(x)
         return hessian_log_determinant(self.__call__, x, jit=jit)
 
@@ -574,8 +574,8 @@ def mean(self, x, logscale=False):
             If the number of features in 'x' does not match the
             number of features the predictor was trained on.
         """
-        x = _validate_array(x, "x")
-        logscale = _validate_bool(logscale, "logscale")
+        x = validate_array(x, "x")
+        logscale = validate_bool(logscale, "logscale")
         x = ensure_2d(x)
         if x.shape[1] != self.n_input_features:
             raise ValueError(
@@ -707,10 +707,10 @@ def mean(self, Xnew, time=None, normalize=False):
             If 'time' is an array and its size does not match 'Xnew'.
         """
         # if time is a scalar, convert it into a 1D array of the same size as Xnew
-        Xnew = _validate_time_x(
+        Xnew = validate_time_x(
             Xnew, time, n_features=self.n_input_features, cast_scalar=True
         )
-        normalize = _validate_bool(normalize, "normalize")
+        normalize = validate_bool(normalize, "normalize")
 
         if normalize:
             if self.n_obs is None or self.n_obs == 0:
@@ -756,7 +756,7 @@ def covariance(self, Xnew, time=None, diag=True):
             If diag=False, returns the full covariance matrix between samples.
         """
         # if time is a scalar, convert it into a 1D array of the same size as Xnew
-        Xnew = _validate_time_x(
+        Xnew = validate_time_x(
             Xnew, time, n_features=self.n_input_features, cast_scalar=True
         )
         return self._covariance(Xnew, diag=diag)
@@ -788,7 +788,7 @@ def mean_covariance(self, Xnew, time=None, diag=True):
             If diag=False, returns the full covariance matrix between samples.
         """
         # if time is a scalar, convert it into a 1D array of the same size as Xnew
-        Xnew = _validate_time_x(
+        Xnew = validate_time_x(
             Xnew, time, n_features=self.n_input_features, cast_scalar=True
         )
         return self._mean_covariance(Xnew, diag=diag)
@@ -821,7 +821,7 @@ def uncertainty(self, Xnew, time=None, diag=True):
         cov : array-like, shape (n_samples, n_samples) if diag=False
             The full covariance matrix between the samples in the new data points.
         """
-        Xnew = _validate_time_x(
+        Xnew = validate_time_x(
             Xnew, time, n_features=self.n_input_features, cast_scalar=True
         )
         return self._covariance(Xnew, diag=diag) + self._mean_covariance(
@@ -865,7 +865,7 @@ def time_derivative(
             The shape of the output array is the same as `x`.
         """
-        Xnew = _validate_time_x(
+        Xnew = validate_time_x(
             x, time, n_features=self.n_input_features, cast_scalar=True
         )
         return super().gradient(Xnew, jit=jit)[:, -1]
@@ -897,7 +897,7 @@ def gradient(self, x, time, jit=True):
             The gradient of the prediction function at each point in `x`.
             The shape of the output array is the same as `x`.
         """
-        Xnew = _validate_time_x(
+        Xnew = validate_time_x(
             x, time, n_features=self.n_input_features, cast_scalar=True
         )
         X, time = Xnew[:, :-1], Xnew[:, -1]
@@ -932,7 +932,7 @@ def hessian(self, x, time, jit=True):
             The Hessian matrix of the prediction function at each point in `x`.
             The shape of the output array is `x.shape + x.shape[1:]`.
         """
-        Xnew = _validate_time_x(
+        Xnew = validate_time_x(
             x, time, n_features=self.n_input_features, cast_scalar=True
         )
         X, time = Xnew[:, :-1], Xnew[:, -1]
@@ -967,7 +967,7 @@ def hessian_log_determinant(self, x, time, jit=True):
             `signs.shape == log_determinants.shape == x.shape[0]`.
         """
 
-        Xnew = _validate_time_x(
+        Xnew = validate_time_x(
             x, time, n_features=self.n_input_features, cast_scalar=True
         )
         X, time = Xnew[:, :-1], Xnew[:, -1]
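Every time-aware method above funnels its input through `validate_time_x`, which appends the time coordinate as a final column. A minimal sketch of that behavior, assuming the helper is importable from `mellon.validation`, that `n_features` is optional (the `compute_ls_time` call below omits it), and that `cast_scalar=True` broadcasts a scalar time over all rows:

    from jax.numpy import array
    from mellon.validation import validate_time_x

    X = array([[0.0, 1.0], [2.0, 3.0]])
    # A scalar time is cast to one entry per row and appended as a column.
    Xt = validate_time_x(X, 0.5, cast_scalar=True)
    states, times = Xt[:, :-1], Xt[:, -1]  # states == X, times == [0.5, 0.5]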
4 changes: 2 additions & 2 deletions mellon/compute_ls_time.py
@@ -4,7 +4,7 @@
 from jax.numpy.linalg import norm
 from jaxopt import ScipyMinimize
 from .density_estimator import DensityEstimator
-from .validation import _validate_time_x
+from .validation import validate_time_x
 
 logger = logging.getLogger("mellon")
 
@@ -55,7 +55,7 @@ def compute_ls_time(
         The unique time points in the training instances. Only returned if `return_data` is True.
     """
-    x = _validate_time_x(x, times)
+    x = validate_time_x(x, times)
     times = x[:, -1]
     states = x[:, :-1]
     unique_times = unique(times)
12 changes: 6 additions & 6 deletions mellon/density_estimator.py
@@ -20,8 +20,8 @@
     DEFAULT_JITTER,
 )
 from .validation import (
-    _validate_string,
-    _validate_array,
+    validate_string,
+    validate_array,
 )
 
 
@@ -223,7 +223,7 @@ def __init__(
             jit=jit,
             check_rank=check_rank,
         )
-        self.d_method = _validate_string(
+        self.d_method = validate_string(
             d_method, "d_method", choices={"fractal", "embedding"}
         )
         self.transform = None
@@ -347,7 +347,7 @@ def prepare_inference(self, x):
         self._prepare_attribute("n_landmarks")
         self._prepare_attribute("rank")
         self._prepare_attribute("gp_type")
-        self._validate_parameter()
+        self.validate_parameter()
         self._prepare_attribute("nn_distances")
         self._prepare_attribute("d")
         self._prepare_attribute("mu")
@@ -401,7 +401,7 @@ def process_inference(self, pre_transformation=None, build_predict=True):
         :rtype: array-like
         """
         if pre_transformation is not None:
-            self.pre_transformation = _validate_array(
+            self.pre_transformation = validate_array(
                 pre_transformation, "pre_transformation"
             )
         self._set_log_density_x()
@@ -493,7 +493,7 @@ def fit_predict(self, x=None, build_predict=False):
         if x is None:
             x = self.x
         else:
-            x = _validate_array(x, "x")
+            x = validate_array(x, "x")
 
         self.fit(x, build_predict=build_predict)
         return self.log_density_x
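Since `prepare_inference` now calls the public `validate_parameter`, a normal fit exercises the whole renamed validation path. A rough end-to-end sketch, assuming `DensityEstimator`'s default constructor arguments suffice; the toy data is illustrative only:

    import numpy as np
    from mellon.density_estimator import DensityEstimator

    X = np.random.rand(100, 2)      # toy training data, for illustration
    est = DensityEstimator()
    # fit_predict validates x, checks parameter consistency via
    # validate_parameter(), and returns log densities at the training points.
    log_density = est.fit_predict(X)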
6 changes: 3 additions & 3 deletions mellon/derivatives.py
@@ -1,7 +1,7 @@
 import jax
 from jax.numpy import isscalar, atleast_2d
 
-from .validation import _validate_1d, _validate_float
+from .validation import validate_1d, validate_float
 
 
 def derivative(function, x, jit=True):
@@ -35,10 +35,10 @@ def get_grad(x):
         return jax.jacrev(function)(x)
 
     if isscalar(x):
-        x = _validate_float(x, "x")
+        x = validate_float(x, "x")
         return get_grad(x)
 
-    x = _validate_1d(x)
+    x = validate_1d(x)
 
     if jit:
         get_grad = jax.jit(get_grad)
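`derivative` dispatches on its input: scalars go through `validate_float`, one-dimensional arrays through `validate_1d`, and the gradient itself comes from `jax.jacrev`. A hypothetical call, assuming the tail of the function (collapsed in this diff) returns `get_grad(x)` for the array branch as well:

    from jax.numpy import array
    from mellon.derivatives import derivative

    f = lambda x: (x ** 2).sum()
    g_vec = derivative(f, array([1.0, 2.0]))      # expected: [2.0, 4.0]
    g_scalar = derivative(lambda t: t ** 2, 3.0)  # expected: 6.0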