diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..9e2cc1a
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,87 @@
+# v1.4.0
+## New Features
+### `with_uncertainty` Parameter
+Integrates a boolean parameter `with_uncertainty` (exposed as `predictor_with_uncertainty` in the estimator constructors) across all estimators: [DensityEstimator](https://mellon.readthedocs.io/en/uncertainty/model.html#mellon.model.DensityEstimator), TimeSensitiveDensityEstimator, FunctionEstimator, and DimensionalityEstimator. It extends the fitted predictor, accessible via the `.predict` property, with the following methods:
+ - `.covariance(X)`: Calculates the (co-)variance of the posterior Gaussian Process (GP).
+   - Almost 0 near landmarks; grows for out-of-sample locations.
+   - Increases with sparsity.
+   - Defaults to `diag=True`, computing only the diagonal of the covariance matrix.
+ - `.mean_covariance(X)`: Computes the (co-)variance induced by the uncertainty of the mean function's GP posterior.
+   - Derived from the Bayesian inference of the latent density function representation.
+   - Increases in areas with little data or low density.
+   - Only available with posterior uncertainty quantification, e.g., `optimizer='advi'`, except for the `FunctionEstimator`, where input uncertainty is specified through the `sigma` parameter.
+   - Defaults to `diag=True`, computing only the diagonal of the covariance matrix.
+ - `.uncertainty(X)`: Combines `.covariance(X)` and `.mean_covariance(X)`.
+   - Defaults to `diag=True`, computing only the diagonal of the covariance matrix.
+   - Its square root gives the standard deviation.
+
+### `gp_type` Parameter
+Introduces the `gp_type` parameter to all relevant [estimators](https://mellon.readthedocs.io/en/uncertainty/model.html) to explicitly specify the Gaussian Process (GP) sparsification strategy, replacing the previously used `method` argument (with options `auto`, `fixed`, and `percent`) that controlled sparsification implicitly. The available options for `gp_type` are:
+ - 'full': Non-sparse GP.
+ - 'full_nystroem': Sparse GP with Nyström rank reduction, lowering computational complexity.
+ - 'sparse_cholesky': Sparse GP using landmarks/inducing points.
+ - 'sparse_nystroem': Improved Nyström rank reduction on a sparse GP with landmarks, balancing accuracy and efficiency.
+
+The new parameter adds validation steps, ensuring that no contradictory parameters are specified. If inconsistencies are detected, a helpful error message guides the user on how to fix the issue. The value can be either a string matching one of the options above or an instance of the `mellon.parameters.GaussianProcessType` Enum. Partial matches log a warning and use the closest match. Defaults to 'sparse_cholesky'.
+
+*Note: Nyström strategies are not applicable to the **FunctionEstimator**.*
+
+### `y_is_mean` Parameter
+Adds a boolean parameter `y_is_mean` to [FunctionEstimator](https://mellon.readthedocs.io/en/uncertainty/model.html#mellon.model.FunctionEstimator), affecting how `y` values are interpreted:
+- **Old Behavior**: `sigma` impacted conditional mean functions and predictions.
+- **Intermediate Behavior**: `sigma` only influenced prediction uncertainty.
+- **New Parameter**: If `y_is_mean=True`, the `y` values are treated as a fixed mean and `sigma` reflects only the prediction uncertainty. If `y_is_mean=False`, `y` is considered a noisy measurement, potentially smoothing the values at the locations `x`.
+
+This change benefits DensityEstimator, TimeSensitiveDensityEstimator, and DimensionalityEstimator, where function values are predicted for out-of-sample locations after the mean GP computation.
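+
+A minimal usage sketch combining the features above (data and shapes here are
+purely illustrative; the estimator constructors expose the flag as
+`predictor_with_uncertainty`):
+
+```python
+import numpy as np
+import mellon
+
+X = np.random.rand(500, 10)  # hypothetical cell-state matrix
+
+est = mellon.DensityEstimator(
+    gp_type="sparse_cholesky",        # explicit sparsification strategy
+    optimizer="advi",                 # posterior uncertainty enables .mean_covariance
+    predictor_with_uncertainty=True,  # attach the uncertainty methods to the predictor
+)
+est.fit(X)
+predictor = est.predict
+
+log_density = predictor(X)                    # mean prediction
+variance = predictor.covariance(X)            # diagonal of the GP posterior covariance
+mean_variance = predictor.mean_covariance(X)  # uncertainty of the mean function
+std = np.sqrt(predictor.uncertainty(X))       # square root gives the standard deviation
+```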
+
+### `check_rank` Parameter
+Introduces the `check_rank` parameter to all relevant [estimators](https://mellon.readthedocs.io/en/uncertainty/model.html). This boolean parameter explicitly controls whether the rank check is performed, specifically in the `gp_type="sparse_cholesky"` case. The rank check assesses whether the chosen landmarks provide adequate complexity by examining the approximate rank of the covariance matrix, issuing a warning if it is insufficient. Allowed values are:
+ - `True`: Always perform the check.
+ - `False`: Never perform the check.
+ - `None` (Default): Perform the check only if `n_landmarks` is less than `n_samples` divided by 10.
+
+The default setting aims to bypass unnecessary computation when the number of landmarks is so abundant that insufficient complexity becomes improbable.
+
+### `normalize` Parameter
+
+The `normalize` parameter is applicable to both the [`.mean`](https://mellon.readthedocs.io/en/uncertainty/serialization.html#mellon.Predictor.mean) method and the `.__call__` method of the [mellon.Predictor](https://mellon.readthedocs.io/en/uncertainty/serialization.html#predictor-class) class. When set to `True`, these methods subtract `log(number of observations)` from the returned value. This feature is particularly useful with the [DensityEstimator](https://mellon.readthedocs.io/en/uncertainty/model.html#mellon.model.DensityEstimator), where normalization adjusts for the number of cells in the training sample, allowing for accurate density comparisons between datasets. This correction accounts for the effect of dataset size, ensuring that differences in total cell numbers are not unduly influential. By default, the parameter is set to `False`, meaning that density differences due to variations in total cell number remain uncorrected.
+
+### `normalize_per_time_point` Parameter
+
+This parameter fine-tunes the `TimeSensitiveDensityEstimator` to handle variations in sampling bias across different time points, ensuring both continuity and differentiability in the resulting density estimate. Notably, it also makes it possible to reflect the growth of a population even if the same number of cells was sampled from each time point.
+
+The normalization is realized by manipulating the nearest neighbor distances
+`nn_distances` to reflect the deviation from an expected cell count.
+
+- **Type**: Optional, accepts `bool`, `list`, `array-like`, or `dict`.
+
+#### Options:
+
+- **`True`:** Normalizes to emulate an even distribution of the total cell count across all time points.
+- **`False`:** Retains raw cell counts at each time point for density estimation.
+- **List/Array-like**: Specifies an ordered sequence of total cell count targets for each time point, starting with the earliest.
+- **Dict**: Associates each unique time point with a specific total cell count target.
+
+#### Notes:
+
+- **Relative Metrics**: While this parameter adjusts for sampling bias, it only requires relative cell counts for comparisons within the dataset; exact counts are not mandatory.
+- **`nn_distances` Precedence**: If `nn_distances` is supplied, this parameter will be bypassed, and the provided distances will be used directly.
+- The default value is `False`.
+
+
+## Enhancements
+ - Optimization by saving the intermediate result `Lp` in the estimators for reuse, enhancing the speed of the predictive function computation in non-Nyström strategies.
+ - The `DimensionalityEstimator.predict` now returns a subclass of the `mellon.Predictor` class instead of a closure, giving access to serialization and uncertainty computations.
+ - Expanded testing.
+ - Propagate logging messages and use the explicit logger name "mellon" everywhere.
+ - Extended parameter validation for the estimators now also applies to the `compute_L` function.
+ - Better string representation of estimators and predictors.
+ - Fixed several edge-case bugs.
+ - Revised some documentation (see b70bb04a4e921ceab63b60026b8033e384a8916a) and included the [Predictor](https://mellon.readthedocs.io/en/uncertainty/predictor.html) page in the Sphinx documentation.
+
+## Changes
+
+ - The mellon.Predictor class now has a method `.mean` that is an alias for `.__call__`.
+ - All mellon.Predictor subclasses `...ConditionalMean...` were renamed to `...Conditional...` since they now also compute `.covariance` and `.mean_covariance`.
+ - All generating methods for mellon.Predictor were renamed from `...conditional_mean...` to `...conditional...`.
+ - A new log message now informs that the normalization is not effective if `d_method != "fractal"`. Additionally, using `normalize=True` in the density predictor triggers a warning that one has to use the non-default `d_method="fractal"` in the `DensityEstimator`.
diff --git a/README.rst b/README.rst
index 05f5c03..2b1f115 100644
--- a/README.rst
+++ b/README.rst
@@ -5,6 +5,9 @@ Mellon
    :target: https://zenodo.org/badge/latestdoi/558998366
 .. image:: https://codecov.io/github/settylab/Mellon/branch/main/graph/badge.svg?token=TKIKXK4MPG
    :target: https://app.codecov.io/github/settylab/Mellon
+.. image:: https://www.codefactor.io/repository/github/settylab/mellon/badge/main
+   :target: https://www.codefactor.io/repository/github/settylab/mellon/overview/main
+   :alt: CodeFactor
 .. image:: https://badge.fury.io/py/mellon.svg
    :target: https://badge.fury.io/py/mellon
 .. image:: https://anaconda.org/conda-forge/mellon/badges/version.svg
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 1f52e22..80b8a28 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -47,8 +47,9 @@ def get_version(rel_path):
     "sphinx.ext.autodoc",
     "nbsphinx",
     "sphinx.ext.napoleon",
-    "sphinx_github_style",
 ]
+if os.environ.get('READTHEDOCS') == 'True':
+    extensions.append("sphinx_github_style")
 
 source_suffix = [".rst", ".md"]
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 2a2087e..2014577 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -7,6 +7,7 @@
    :caption: Modules:
 
    Model
+   Predictor
    Serialization
    Utilities
    Covariance Functions
@@ -21,7 +22,7 @@
    notebooks/trajectory-trends_tutorial.ipynb
    notebooks/gene_change_analysis_tutorial.ipynb
    notebooks/time-series_tutorial.ipynb
-
+
 .. include:: ../../README.rst
 
 .. toctree::
diff --git a/docs/source/predictor.rst b/docs/source/predictor.rst
new file mode 100644
index 0000000..2b2e98d
--- /dev/null
+++ b/docs/source/predictor.rst
@@ -0,0 +1,98 @@
+Predictors
+==========
+
+Predictors in the Mellon framework can be invoked directly via their `__call__`
+method to produce function estimates at new locations. These predictors can
+also double as Gaussian Processes, offering uncertainty estimation options.
+They also come with serialization capabilities, detailed in :ref:`serialization`.
+
+Basic Usage
+-----------
+
+To generate estimates for new, out-of-sample locations, instantiate a
+predictor and call it like a function:
+
+.. code-block:: python
+   :caption: Example of accessing the :class:`mellon.Predictor` from the
+       :class:`mellon.model.DensityEstimator` in the Mellon framework
+   :name: example-usage-density-predictor
+
+   model = mellon.model.DensityEstimator(...)  # Initialize the model with appropriate arguments
+   model.fit(X)  # Fit the model to the data
+   predictor = model.predict  # Obtain the predictor object
+   predicted_values = predictor(Xnew)  # Generate predictions for new locations
+
+
+Uncertainty
+-----------
+
+If the predictor was generated with uncertainty estimates (typically by passing
+`predictor_with_uncertainty=True` and `optimizer="advi"` to the model class,
+e.g., :class:`mellon.model.DensityEstimator`), it provides methods for computing
+the variance at these locations and the covariance with any other location.
+
+- Variance Methods:
+
+  - :meth:`mellon.Predictor.covariance`
+  - :meth:`mellon.Predictor.mean_covariance`
+  - :meth:`mellon.Predictor.uncertainty`
+
+Sub-Classes
+-----------
+
+The `Predictor` module in the Mellon framework features a variety of
+specialized subclasses of :class:`mellon.Predictor`. The specific subclass
+instantiated by the model depends on two factors:
+
+- `gp_type`: This argument determines the type of Gaussian Process used internally.
+- The nature of the predicted output: This can be real-valued, strictly positive, or time-sensitive.
+
+The `gp_type` argument mainly affects the internal mathematical operations,
+whereas the nature of the predicted value dictates the subclass's functional
+capabilities:
+
+- **Real-valued Predictions**: Such as log-density estimates, :class:`mellon.Predictor`.
+- **Positive-valued Predictions**: Such as dimensionality estimates, :class:`mellon.base_predictor.ExpPredictor`.
+- **Time-sensitive Predictions**: Such as time-sensitive density estimates, :class:`mellon.base_predictor.PredictorTime`.
+
+
+Vanilla Predictor
+-----------------
+
+Utilized in the following methods:
+
+- :attr:`mellon.model.DensityEstimator.predict`
+- :attr:`mellon.model.DimensionalityEstimator.predict_density`
+- :attr:`mellon.model.FunctionEstimator.predict`
+
+.. autoclass:: mellon.Predictor
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :exclude-members: n_obs, n_input_features
+
+Exponential Predictor
+---------------------
+
+- Used in :attr:`mellon.model.DimensionalityEstimator.predict`
+- Predicted values are strictly positive. Variance is expressed in log scale.
+
+.. autoclass:: mellon.base_predictor.ExpPredictor
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :exclude-members: n_obs, n_input_features
+
+Time-sensitive Predictor
+------------------------
+
+- Utilized in :attr:`mellon.model.TimeSensitiveDensityEstimator.predict`
+- Special arguments `time` and `multi_time` permit time-specific predictions.
+
+.. autoclass:: mellon.base_predictor.PredictorTime
+   :members:
+   :undoc-members:
+   :show-inheritance:
+   :exclude-members: n_obs, n_input_features
+
diff --git a/docs/source/serialization.rst b/docs/source/serialization.rst
index 6ede361..daed499 100644
--- a/docs/source/serialization.rst
+++ b/docs/source/serialization.rst
@@ -1,30 +1,34 @@
+.. _serialization:
+
 Serialization
 =============
 
-The Mellon module provides a comprehensive suite of tools for serializing and deserializing estimators. This functionality is particularly useful in computational biology where models might be trained on one dataset and then used for making predictions on new datasets.
The ability to serialize an estimator allows you to save the state of a model after it has been trained, and then load it later for predictions without needing to retrain the model. - -For instance, you might have a large dataset on which you train an estimator. Training might take a long time due to the size and complexity of the data. With serialization, you can save the trained estimator and then easily share it with collaborators, apply it to new data, or use it in a follow-up study. - -When an estimator is serialized, it includes a variety of metadata, including the serialization date, Python version, the class name of the estimator, and the Mellon version. -This metadata serves as a detailed record of the state of your environment at the time of serialization, which can be crucial for reproducibility in scientific research and for understanding the conditions under which the estimator was originally trained. +The Mellon module facilitates the serialization and deserialization of +predictors, streamlining model transfer and reuse. This is especially relevant +in computational biology for applying pre-trained models to new datasets. +Serialization captures essential metadata like date, Python version, estimator +class, and Mellon version, ensuring research reproducibility and context for +the original training. -All estimators in Mellon that inherit from the :class:`BaseEstimator` class, including :class:`mellon.model.DensityEstimator`, :class:`mellon.model.FunctionEstimator`, and :class:`mellon.model.DimensionalityEstimator`, have serialization capabilities. +After fitting data, all Mellon models generate a predictor via their `.predict` +property, including model classes like :class:`mellon.model.DensityEstimator`, +:class:`mellon.model.TimeSensitiveDensityEstimator`, +:class:`mellon.model.FunctionEstimator`, and +:class:`mellon.model.DimensionalityEstimator`. Predictor Class --------------- -The `Predictor` class, accessible through the `predict` property of an estimator, handles the serialization and deserialization process. - -.. autoclass:: mellon.Predictor - :members: - :undoc-members: - :show-inheritance: +All predictors inherit their serialization methods from :class:`mellon.Predictor`. Serialization to AnnData ------------------------ -Estimators can be serialized to an `AnnData`_ object. The `log_density` computation for the AnnData object shown below is a simplified example. For a more comprehensive guide on how to properly preprocess data to compute the log density, refer to our +Predictors can be serialized to an `AnnData`_ object. The `log_density` +computation for the AnnData object shown below is a simplified example. For a +more comprehensive guide on how to properly preprocess data to compute the log +density, refer to our `basic tutorial notebook `_. @@ -48,6 +52,9 @@ Estimators can be serialized to an `AnnData`_ object. The `log_density` computat Deserialization from AnnData ---------------------------- +The function :meth:`mellon.Predictor.from_dict` can deserialize the +:class:`mellon.Predictor` and any sub class. + .. code-block:: python # Load the AnnData object @@ -62,7 +69,11 @@ Deserialization from AnnData Serialization to File --------------------- -Mellon supports serialization to a human-readable JSON file and compressed file formats such as .gz (gzip) and .bz2 (bzip2). +Mellon supports serialization to a human-readable JSON file and compressed file +formats such as .gz (gzip) and .bz2 (bzip2). 
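+
+A short sketch of writing a compressed file (the file name is hypothetical, and
+we assume the `.gz` suffix selects gzip compression):
+
+.. code-block:: python
+
+    predictor = model.predict
+    # Assumption: the compression format is chosen from the file suffix.
+    predictor.to_json("density_predictor.json.gz")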
+ +The function :meth:`mellon.Predictor.from_json` can deserialize the +:class:`mellon.Predictor` and any sub class. .. code-block:: python @@ -78,7 +89,8 @@ Mellon supports serialization to a human-readable JSON file and compressed file Deserialization from File ------------------------- -Mellon supports deserialization from JSON and compressed file formats. The compression method can be inferred from the file extension. +Mellon supports deserialization from JSON and compressed file formats. The +compression method can be inferred from the file extension. .. code-block:: python diff --git a/mellon/__init__.py b/mellon/__init__.py index 145afa5..dd29d4b 100644 --- a/mellon/__init__.py +++ b/mellon/__init__.py @@ -17,7 +17,7 @@ from . import _conditional as conditional from . import _derivatives as derivatives -__version__ = "1.3.1" +__version__ = "1.4.0" __all__ = [ "DensityEstimator", @@ -37,6 +37,8 @@ "derivatives", "__version__", ] +# Set up logger +Log() # Set default configuration at import time jaxconfig.update("jax_enable_x64", True) diff --git a/mellon/_conditional.py b/mellon/_conditional.py index dd719ed..59fa715 100644 --- a/mellon/_conditional.py +++ b/mellon/_conditional.py @@ -1,17 +1,23 @@ from .conditional import ( - FullConditionalMean, - FullConditionalMeanTime, - LandmarksConditionalMean, - LandmarksConditionalMeanTime, - LandmarksConditionalMeanCholesky, - LandmarksConditionalMeanCholeskyTime, + FullConditional, + ExpFullConditional, + FullConditionalTime, + LandmarksConditional, + ExpLandmarksConditional, + LandmarksConditionalTime, + LandmarksConditionalCholesky, + ExpLandmarksConditionalCholesky, + LandmarksConditionalCholeskyTime, ) __all__ = [ - "FullConditionalMean", - "FullConditionalMeanTime", - "LandmarksConditionalMean", - "LandmarksConditionalMeanTime", - "LandmarksConditionalMeanCholesky", - "LandmarksConditionalMeanCholeskyTime", + "FullConditional", + "ExpFullConditional", + "FullConditionalTime", + "LandmarksConditional", + "ExpLandmarksConditional", + "LandmarksConditionalTime", + "LandmarksConditionalCholesky", + "ExpLandmarksConditionalCholesky", + "LandmarksConditionalCholeskyTime", ] diff --git a/mellon/_inference.py b/mellon/_inference.py index f8a109b..ae5297d 100644 --- a/mellon/_inference.py +++ b/mellon/_inference.py @@ -6,9 +6,10 @@ minimize_adam, minimize_lbfgsb, compute_log_density_x, - compute_conditional_mean, - compute_conditional_mean_times, - compute_conditional_mean_explog, + compute_parameter_cov_factor, + compute_conditional, + compute_conditional_times, + compute_conditional_explog, generate_gaussian_sample, calculate_gaussian_logpdf, calculate_elbo, @@ -24,9 +25,10 @@ "minimize_adam", "minimize_lbfgsb", "compute_log_density_x", - "compute_conditional_mean", - "compute_conditional_mean_times", - "compute_conditional_mean_explog", + "compute_parameter_cov_factor", + "compute_conditional", + "compute_conditional_times", + "compute_conditional_explog", "generate_gaussian_sample", "calculate_gaussian_logpdf", "calculate_elbo", diff --git a/mellon/_parameters.py b/mellon/_parameters.py index d95fc48..4a8e26e 100644 --- a/mellon/_parameters.py +++ b/mellon/_parameters.py @@ -1,7 +1,10 @@ from .parameters import ( compute_d, compute_d_factal, + compute_rank, + compute_n_landmarks, compute_landmarks, + compute_gp_type, compute_nn_distances, compute_nn_distances_within_time_points, compute_distances, @@ -17,7 +20,10 @@ __all__ = [ "compute_d", "compute_d_factal", + "compute_rank", + "compute_n_landmarks", "compute_landmarks", + 
"compute_gp_type", "compute_nn_distances", "compute_nn_distances_within_time_points", "compute_distances", diff --git a/mellon/_util.py b/mellon/_util.py index 0d9bad9..7073f55 100644 --- a/mellon/_util.py +++ b/mellon/_util.py @@ -5,6 +5,8 @@ local_dimensionality, Log, set_jax_config, + GaussianProcessType, + object_str, ) __all__ = [ @@ -14,4 +16,6 @@ "local_dimensionality", "Log", "set_jax_config", + "GaussianProcessType", + "object_str", ] diff --git a/mellon/base_cov.py b/mellon/base_cov.py index 83efd55..9b7368c 100644 --- a/mellon/base_cov.py +++ b/mellon/base_cov.py @@ -1,14 +1,17 @@ import sys +import logging +from jax import vmap +from jax.numpy import expand_dims, reshape from abc import ABC, abstractmethod from importlib import import_module from datetime import datetime import json -from .util import Log, make_serializable, deserialize +from .util import make_serializable, deserialize MELLON_NAME = __name__.split(".")[0] -logger = Log() +logger = logging.getLogger("mellon") class Covariance(ABC): @@ -40,6 +43,30 @@ def k(x, y): def __call__(self, x, y): return self.k(x, y) + def diag(self, x): + """ + Compute the diagonal of a covariance matrix. + + This function expands the input vectors, maps a function over them, + reshapes the result and returns the diagonal of the covariance matrix. + + Parameters + ---------- + x : ndarray + Input array where the first dimension is the sample dimension and + the second dimension corresponds to the state dimensions of the samples (cells). + + Returns + ------- + diag : ndarray + The diagonal of the covariance matrix. + """ + + x = expand_dims(x, 1) + res = vmap(self.k)(x, x) + diag = reshape(res, res.shape[:-2]) + return diag + def __add__(self, other): return Add(self, other) diff --git a/mellon/base_model.py b/mellon/base_model.py index 8bd961b..b022b2f 100644 --- a/mellon/base_model.py +++ b/mellon/base_model.py @@ -1,5 +1,5 @@ +import logging from .cov import Matern52 -from .decomposition import DEFAULT_RANK, DEFAULT_METHOD from .inference import ( minimize_adam, run_advi, @@ -10,17 +10,21 @@ DEFAULT_JIT, ) from .parameters import ( + compute_rank, + compute_n_landmarks, compute_landmarks, + compute_gp_type, compute_nn_distances, compute_ls, compute_cov_func, + compute_Lp, compute_L, - DEFAULT_N_LANDMARKS, ) from .util import ( test_rank, + object_str, DEFAULT_JITTER, - Log, + GaussianProcessType, ) from .validation import ( _validate_positive_int, @@ -30,9 +34,12 @@ _validate_string, _validate_bool, _validate_array, + _validate_float_or_iterable_numerical, +) +from .parameter_validation import ( + _validate_params, _validate_cov_func_curry, _validate_cov_func, - _validate_float_or_iterable_numerical, ) @@ -40,7 +47,7 @@ RANK_FRACTION_THRESHOLD = 0.8 SAMPLE_LANDMARK_RATIO = 10 -logger = Log() +logger = logging.getLogger("mellon") class BaseEstimator: @@ -51,41 +58,47 @@ class BaseEstimator: def __init__( self, cov_func_curry=DEFAULT_COV_FUNC, - n_landmarks=DEFAULT_N_LANDMARKS, - rank=DEFAULT_RANK, - method=DEFAULT_METHOD, + n_landmarks=None, + rank=None, jitter=DEFAULT_JITTER, optimizer=DEFAULT_OPTIMIZER, n_iter=DEFAULT_N_ITER, init_learn_rate=DEFAULT_INIT_LEARN_RATE, landmarks=None, + gp_type=None or GaussianProcessType, nn_distances=None, d=None, mu=0, ls=None, ls_factor=1, cov_func=None, + Lp=None, L=None, initial_value=None, + predictor_with_uncertainty=False, jit=DEFAULT_JIT, + check_rank=None, ): self.cov_func_curry = _validate_cov_func_curry( cov_func_curry, cov_func, "cov_func_curry" ) - self.n_landmarks = 
_validate_positive_int(n_landmarks, "n_landmarks") - self.rank = _validate_float_or_int(rank, "rank") - self.method = _validate_string( - method, "method", choices={"percent", "fixed", "auto"} + self.n_landmarks = _validate_positive_int( + n_landmarks, "n_landmarks", optional=True ) + self.rank = _validate_float_or_int(rank, "rank", optional=True) self.jitter = _validate_positive_float(jitter, "jitter") self.landmarks = _validate_array(landmarks, "landmarks", optional=True) + self.gp_type = GaussianProcessType.from_string(gp_type, optional=True) self.nn_distances = _validate_array(nn_distances, "nn_distances", optional=True) self.mu = _validate_float(mu, "mu", optional=True) self.ls = _validate_positive_float(ls, "ls", optional=True) self.ls_factor = _validate_positive_float(ls_factor, "ls_factor") self.cov_func = _validate_cov_func(cov_func, "cov_func", optional=True) + self.Lp = _validate_array(Lp, "Lp", optional=True) self.L = _validate_array(L, "L", optional=True) - self.d = _validate_float_or_iterable_numerical(d, "d", optional=True) + self.d = _validate_float_or_iterable_numerical( + d, "d", optional=True, positive=True + ) self.initial_value = _validate_array( initial_value, "initial_value", optional=True ) @@ -96,7 +109,11 @@ def __init__( self.init_learn_rate = _validate_positive_float( init_learn_rate, "init_learn_rate" ) + self.predictor_with_uncertainty = _validate_bool( + predictor_with_uncertainty, "predictor_with_uncertainty" + ) self.jit = _validate_bool(jit, "jit") + self.check_rank = _validate_bool(check_rank, "check_rank", optional=True) self.x = None self.pre_transformation = None @@ -105,23 +122,35 @@ def __str__(self): def __repr__(self): name = self.__class__.__name__ + landmarks = object_str(self.landmarks, ["landmarks", "dims"]) + Lp = object_str(self.Lp, ["landmarks", "landmarks"]) + L = object_str(self.L, ["cells", "ranks"]) + nn_distances = object_str(self.nn_distances, ["cells"]) + initial_value = object_str(self.initial_value, ["ranks"]) + d = object_str(self.d, ["cells"]) string = ( f"{name}(" - f"cov_func_curry={self.cov_func_curry}, " - f"n_landmarks={self.n_landmarks}, " - f"rank={self.rank}, " - f"jitter={self.jitter}, " - f"landmarks={self.landmarks}, " + f"\n cov_func_curry={self.cov_func_curry}," + f"\n n_landmarks={self.n_landmarks}," + f"\n rank={self.rank}," + f"\n gp_type={self.gp_type}," + f"\n jitter={self.jitter}, " + f"\n optimizer={self.optimizer}," + f"\n landmarks={landmarks}," + f"\n nn_distances={nn_distances}," + f"\n d={d}," + f"\n mu={self.mu}," + f"\n ls={self.ls}," + f"\n ls_factor={self.ls_factor}," + f"\n cov_func={self.cov_func}," + f"\n Lp={Lp}," + f"\n L={L}," + f"\n initial_value={initial_value}," + f"\n predictor_with_uncertainty={self.predictor_with_uncertainty}," + f"\n jit={self.jit}," + f"\n check_rank={self.check_rank}," + "\n)" ) - if self.nn_distances is None: - string += "nn_distances=None, " - else: - string += "nn_distances=nn_distances, " - string += f"mu={self.mu}, " f"ls={self.mu}, " f"cov_func={self.cov_func}, " - if self.L is None: - string += "L=None, " - else: - string += "L=L, " return string def __call__(self, x=None): @@ -162,6 +191,13 @@ def set_x(self, x): self.x = _validate_array(x, "x") return self.x + def _compute_n_landmarks(self): + gp_type = self.gp_type + n_samples = self.x.shape[0] + landmarks = self.landmarks + n_landmarks = compute_n_landmarks(gp_type, n_samples, landmarks) + return n_landmarks + def _compute_landmarks(self): x = self.x n_landmarks = self.n_landmarks @@ -176,6 +212,18 @@ def 
_compute_landmarks(self): landmarks = compute_landmarks(x, n_landmarks=n_landmarks) return landmarks + def _compute_rank(self): + gp_type = self.gp_type + rank = compute_rank(gp_type) + return rank + + def _compute_gp_type(self): + n_landmarks = self.n_landmarks + rank = self.rank + n_samples = self.x.shape[0] + gp_type = compute_gp_type(n_landmarks, rank, n_samples) + return gp_type + def _compute_nn_distances(self): x = self.x logger.info("Computing nearest neighbor distances.") @@ -195,62 +243,67 @@ def _compute_cov_func(self): logger.info("Using covariance function %s.", str(cov_func)) return cov_func + def _compute_Lp(self): + """ + This function calculates the lower triangular matrix L that is needed for + computations involving the predictive of the Gaussian Process model. + + It has the side effect of settling self.L + """ + x = self.x + cov_func = self.cov_func + gp_type = self.gp_type + landmarks = self.landmarks + jitter = self.jitter + Lp = compute_Lp( + x, + cov_func, + gp_type, + landmarks, + sigma=0, + jitter=jitter, + ) + return Lp + def _compute_L(self): """ This function calculates the lower triangular matrix L that is needed for computations involving the covariance matrix of the Gaussian Process model. - """ - # Extract instance attributes + It has the side effect of settling self.Lp + """ x = self.x cov_func = self.cov_func + gp_type = self.gp_type landmarks = self.landmarks - n_samples = x.shape[0] - n_landmarks = n_samples if landmarks is None else landmarks.shape[0] + Lp = self.Lp rank = self.rank - method = self.method jitter = self.jitter - - is_rank_full = ( - isinstance(rank, int) - and rank == n_landmarks - or isinstance(rank, float) - and rank == 1.0 + check_rank = self.check_rank + + L = compute_L( + x, + cov_func, + gp_type, + landmarks=landmarks, + Lp=Lp, + rank=rank, + sigma=0, + jitter=jitter, ) - # Log the method and rank used for computation - if not is_rank_full and method != "fixed": - logger.info( - f'Computing rank reduction using "{method}" method ' - f"retaining > {rank:.2%} of variance." - ) - elif not is_rank_full: - logger.info( - f'Computing rank reduction to rank {rank} using "{method}" method.' - ) - - try: - # Compute the lower triangular matrix L - L = compute_L( - x, - cov_func, - landmarks=landmarks, - rank=rank, - method=method, - jitter=jitter, - ) - except Exception as e: - logger.error(f"Error during computation of L: {e}") - raise - new_rank = L.shape[1] + n_samples = x.shape[0] + if landmarks is None: + n_landmarks = n_samples + else: + n_landmarks = landmarks.shape[0] # Check if the new rank is too high in comparison to the number of landmarks if ( - not is_rank_full - and method != "fixed" - and new_rank > (rank * RANK_FRACTION_THRESHOLD * n_landmarks) - ): + gp_type == GaussianProcessType.SPARSE_NYSTROEM + or gp_type == GaussianProcessType.FULL_NYSTROEM + ) and new_rank > (rank * RANK_FRACTION_THRESHOLD * n_landmarks): logger.warning( f"Shallow rank reduction from {n_landmarks:,} to {new_rank:,} " "indicates underrepresentation by landmarks. 
Consider " @@ -259,10 +312,10 @@ def _compute_L(self): # Check if the number of landmarks is sufficient for the number of samples if ( - is_rank_full - and n_landmarks is not None + check_rank is None + and gp_type == GaussianProcessType.SPARSE_CHOLESKY and SAMPLE_LANDMARK_RATIO * n_landmarks < n_samples - ): + ) or (check_rank is not None and check_rank): logger.info( "Estimating approximation accuracy " f"since {n_samples:,} samples are more than {SAMPLE_LANDMARK_RATIO} x " @@ -270,8 +323,20 @@ def _compute_L(self): ) test_rank(L, threshold=RANK_FRACTION_THRESHOLD) logger.info(f"Using rank {new_rank:,} covariance representation.") + return L + def _validate_parameter(self): + """ + Make sure there are no contradictions in the parameter settings. + """ + rank = self.rank + gp_type = self.gp_type + n_samples = self.x.shape[0] + n_landmarks = self.n_landmarks + landmarks = self.landmarks + _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) + def _run_inference(self): function = self.loss_func initial_value = self.initial_value @@ -288,6 +353,7 @@ def _run_inference(self): jit=self.jit, ) self.pre_transformation = results.pre_transformation + self.pre_transformation_std = None self.opt_state = results.opt_state self.losses = results.losses elif optimizer == "advi": @@ -308,6 +374,7 @@ def _run_inference(self): jit=self.jit, ) self.pre_transformation = results.pre_transformation + self.pre_transformation_std = None self.opt_state = results.opt_state self.losses = [ results.loss, diff --git a/mellon/base_predictor.py b/mellon/base_predictor.py index 6ccedec..de44683 100644 --- a/mellon/base_predictor.py +++ b/mellon/base_predictor.py @@ -1,6 +1,10 @@ import sys +import logging from importlib import import_module +from packaging import version from abc import ABC, abstractmethod +from functools import wraps +from typing import Union, Set, List from datetime import datetime import gzip @@ -8,56 +12,96 @@ import json +from jax.numpy import exp, log from .base_cov import Covariance from .util import ( - Log, make_serializable, deserialize, ensure_2d, make_multi_time_argument, + object_str, ) from .derivatives import ( gradient, hessian, hessian_log_determinant, ) -from .validation import _validate_time_x, _validate_float, _validate_array +from .validation import ( + _validate_time_x, + _validate_float, + _validate_array, + _validate_bool, +) -logger = Log() +logger = logging.getLogger("mellon") class Predictor(ABC): """ Abstract base class for predictor models. It provides a common interface for all subclasses, which are expected to - implement the `__call__` method for making predictions. + implement the `_mean` method for making predictions. An instance `predictor` of a subclass of `Predictor` can be used to make a prediction by calling it with input data `x`: >>> y = predictor(x) - It is the responsibility of subclasses to define the behaviour of `_predict`. + It is the responsibility of subclasses to define the behaviour of `_mean`. Methods ------- - __call__(x: Union[array-like, pd.DataFrame]): - This makes predictions for an input `x`. The input data type can be either an array-like object - (like list or numpy array) or a pandas DataFrame. + __call__(x: Union[array-like, pd.DataFrame], normalize: bool = False): + + Equivalent to calling the `mean` method, this uses the trained model to make + predictions based on the input array, x. + + The prediction corresponds to the mean of the Gaussian Process conditional + distribution of predictive functions. 
+ + The input array must be 2D with the length of its second dimension matching the number + of features used in training the model. Parameters ---------- - x : array-like or pandas.DataFrame - The input data on which to make a prediction. + x : array-like + The input data to the predictor, having shape (n_samples, n_input_features). + normalize : bool + Optional normalization by subtracting log(self.n_obs) + (number of cells trained on), applicable only for cell-state density predictions. + Default is False. Returns ------- - This method returns the predictions made by the model on the input data `x`. - The prediction is usually array-like with as many entries as `x.shape[0]`. + array + The predicted output generated by the model. + + Raises + ------ + ValueError + If the number of features in 'x' does not align with the number of features the + predictor was trained on. + + Attributes + ---------- + n_obs : int + The number of samples or cells that the model was trained on. This attribute is critical + for normalization purposes, particularly when the `normalize` parameter in the `__call__` + method is set to `True`. + + n_input_features : int + The number of features/dimensions of the cell-state representation the predictor was + trained on. This is used for validation of input data. """ # number of features of input data (x.shape[1]) to be specified in __init__ n_input_features: int + # number of observations trained on (x.shape[0]) to be specified in __init__ + n_obs: int + + # a set of attribute names that should be saved to reconstruct the object + _state_variables: Union[Set, List] + @abstractmethod def __init__(self): """Initialize the predictor. Must be overridden by subclasses.""" @@ -67,29 +111,35 @@ def __str__(self): return self.__repr__() def __repr__(self): + n_obs = "None" if self.n_obs is None else f"{self.n_obs:,}" string = ( 'A predictor of class "' + self.__class__.__name__ + '" with covariance function "' + repr(self.cov_func) - + '" and data:\n' + + f'" trained on {n_obs} observations ' + + f"with {self.n_input_features:,} features " + + "and data:\n" + "\n".join( [ - str(key) + ": " + repr(getattr(self, key)) - for key in self._data_dict().keys() + str(key) + ": " + object_str(v) + for key, v in self._data_dict().items() ] ) ) return string @abstractmethod - def _predict(self, *args, **kwars): + def _mean(self, *args, **kwargs): """Call the predictor. Must be overridden by subclasses.""" - def __call__(self, x): + def mean(self, x, normalize=False): """ Use the trained model to make a prediction based on the input array, x. + The prediction represents the mean of the Gaussian Process conditional + distribution of predictive functions. + The input array should be 2D with its second dimension's length equal to the number of features used in training the model. @@ -98,6 +148,10 @@ def __call__(self, x): x : array-like The input data to the predictor. The array should have shape (n_samples, n_input_features). + normalize : bool + Whether to normalize the value by subtracting log(self.n_obs) + (number of cells trained on). Applicable only for cell-state density predictions. + Default is False. Returns ------- @@ -112,27 +166,136 @@ def __call__(self, x): """ x = _validate_array(x, "x") x = ensure_2d(x) + normalize = _validate_bool(normalize, "normalize") + if x.shape[1] != self.n_input_features: raise ValueError( f"The predictor was trained on data with {self.n_input_features} features. " f"However, the provided input data has {x.shape[1]} features. 
" "Please ensure that the input data has the same number of features as the training data." ) - return self._predict(x) + if normalize: + if self.n_obs is None or self.n_obs == 0: + message = ( + "Cannot normalize without n_obs. Please set self.n_obs to the number " + "of samples/cells trained on to enable normalization." + ) + logger.error(message) + raise ValueError(message) + logger.warning( + 'The normalization is only effective if the density was trained with d_method="fractal".' + ) + return self._mean(x) - log(self.n_obs) + else: + return self._mean(x) + + __call__ = mean @abstractmethod + def _covariance(self, *args, **kwars): + """Compute the covariance. Must be overridden by subclasses.""" + + def covariance(self, x, diag=True): + """ + Computes the covariance of the Gaussian Process distribution of functions + over new data points or cell states. + + Parameters + ---------- + x : array-like, shape (n_samples, n_features) + The new data points for which to compute the covariance. + diag : boolean, optional (default=True) + Whether to return the variance (True) or the full covariance matrix (False). + + Returns + ------- + var : array-like, shape (n_samples,) + If diag=True, returns the variances for each sample. + cov : array-like, shape (n_samples, n_samples) + If diag=False, returns the full covariance matrix between samples. + """ + x = _validate_array(x, "x") + x = ensure_2d(x) + if x.shape[1] != self.n_input_features: + raise ValueError( + f"The predictor was trained on data with {self.n_input_features} features. " + f"However, the provided input data has {x.shape[1]} features. " + "Please ensure that the input data has the same number of features as the training data." + ) + return self._covariance(x, diag=diag) + + @abstractmethod + def _mean_covariance(self, *args, **kwars): + """Compute the covariance of the mean. Must be overridden by subclasses.""" + + def mean_covariance(self, x, diag=True): + """ + Computes the uncertainty of the mean of the Gaussian process induced by + the uncertainty of the latent representation of the mean function. + + Parameters + ---------- + x : array-like, shape (n_samples, n_features) + The new data points for which to compute the uncertainty. + diag : boolean, optional (default=True) + Whether to compute the variance (True) or the full covariance matrix (False). + + Returns + ------- + var : array-like, shape (n_samples,) + If diag=True, returns the variances for each sample. + cov : array-like, shape (n_samples, n_samples) + If diag=False, returns the full covariance matrix between samples. + """ + x = _validate_array(x, "x") + x = ensure_2d(x) + if x.shape[1] != self.n_input_features: + raise ValueError( + f"The predictor was trained on data with {self.n_input_features} features. " + f"However, the provided input data has {x.shape[1]} features. " + "Please ensure that the input data has the same number of features as the training data." + ) + return self._mean_covariance(x, diag=diag) + + def uncertainty(self, x, diag=True): + """ + Computes the total uncertainty of the predicted values quantified by their variance + or covariance. + + The total uncertainty is defined by `.covariance` + `.mean_covariance`. + + Parameters + ---------- + x : array-like, shape (n_samples, n_features) + The new data points for which to compute the uncertainty. + diag : bool, optional (default=True) + Whether to compute the variance (True) or the full covariance matrix (False). 
+ + Returns + ------- + var : array-like, shape (n_samples,) if diag=True + The variances for each sample in the new data points. + cov : array-like, shape (n_samples, n_samples) if diag=False + The full covariance matrix between the samples in the new data points. + """ + x = _validate_array(x, "x") + x = ensure_2d(x) + if x.shape[1] != self.n_input_features: + raise ValueError( + f"The predictor was trained on data with {self.n_input_features} features. " + f"However, the provided input data has {x.shape[1]} features. " + "Please ensure that the input data has the same number of features as the training data." + ) + return self._covariance(x, diag=diag) + self._mean_covariance(x, diag=diag) + def _data_dict(self): - """Return a dictionary containing the predictor's state data. + """Returns a dictionary containing the predictor's state data. All arrays nee to be (jax) numpy arrays for serialization. - This method must be implemented by subclasses. It should return a - dictionary where each key-value pair corresponds to an attribute of - the predictor and its current state. - :return: A dictionary containing the predictor's state data. :rtype: dict """ - pass + return {key: getattr(self, key) for key in self._state_variables} def gradient(self, x, jit=True): R""" @@ -148,7 +311,8 @@ def gradient(self, x, jit=True): """ x = _validate_array(x, "x") x = ensure_2d(x) - return gradient(self.__call__, x, jit=jit) + + return gradient(self._mean, x, jit=jit) def hessian(self, x, jit=True): R""" @@ -200,7 +364,13 @@ def __getstate__(self): metaversion = getattr(metamodule, "__version__", "NA") version = getattr(module, "__version__", metaversion) data = self._data_dict() - data.update({"n_input_features": self.n_input_features}) + data.update( + { + "n_input_features": self.n_input_features, + "n_obs": self.n_obs, + "_state_variables": self._state_variables, + } + ) data = {k: make_serializable(v) for k, v in data.items()} state = { @@ -325,6 +495,23 @@ def from_dict(cls, data_dict): """ clsname = data_dict["metadata"]["classname"] module_name = data_dict["metadata"]["module_name"] + module_version = data_dict["metadata"]["module_version"] + + if version.parse(module_version) < version.parse("1.4.0"): + message = ( + f"Loading a predictor written by mellon {module_version} < 1.4.0. " + "Please set predictor.n_obs to enable normalization." + ) + logger.warning(message) + if module_name == "mellon.conditional": + clsname = clsname.replace("ConditionalMean", "Conditional") + data_dict["data"]["n_obs"] = data_dict["data"].get("n_obs", None) + state_vars = set(data_dict["data"].keys()) - { + "n_input_features", + } + data_dict["data"]["_state_variables"] = data_dict["data"].get( + "_state_variables", state_vars + ) module = import_module(module_name) Subclass = getattr(module, clsname) @@ -348,6 +535,82 @@ def from_json_str(cls, json_str): return cls.from_dict(data_dict) +class ExpPredictor(Predictor): + """ + Abstract base class for predictor models which returs the exponent of its `_mean` method upon a call. + + An instance `predictor` of a subclass of `Predictor` can be used to make a prediction by calling it with input data `x`: + + >>> y = predictor(x) + + It is the responsibility of subclasses to define the behaviour of `_mean`. + """ + + def mean(self, x, logscale=False): + """ + Use the trained model to make a prediction based on the input array, x. + + The input array should be 2D with its second dimension's length + equal to the number of features used in training the model. 
+ + Parameters + ---------- + x : array-like + The input data to the predictor. + The array should have shape (n_samples, n_input_features). + + logscale : bool + Weather the predicted value should be returned in log scale. + Default is False. + + Returns + ------- + array + The predicted output generated by the model. + + Raises + ------ + ValueError + If the number of features in 'x' does not match the + number of features the predictor was trained on. + """ + x = _validate_array(x, "x") + logscale = _validate_bool(logscale, "logscale") + x = ensure_2d(x) + if x.shape[1] != self.n_input_features: + raise ValueError( + f"The predictor was trained on data with {self.n_input_features} features. " + f"However, the provided input data has {x.shape[1]} features. " + "Please ensure that the input data has the same number of features as the training data." + ) + if logscale: + return self._mean(x) + return exp(self._mean(x)) + + __call__ = mean + + @wraps(Predictor.covariance) + def covariance(self, *args, **kwargs): + logger.warning( + "The covariance will be computed for the predicted value in log scale." + ) + return super().covariance(*args, **kwargs) + + @wraps(Predictor.mean_covariance) + def mean_covariance(self, *args, **kwargs): + logger.warning( + "The mean_covariance will be computed for the predicted value in log scale." + ) + return super().mean_covariance(*args, **kwargs) + + @wraps(Predictor.uncertainty) + def uncertainty(self, *args, **kwargs): + logger.warning( + "The uncertainty will be computed for the predicted value in log scale." + ) + return super().uncertainty(*args, **kwargs) + + class PredictorTime(Predictor): """ Abstract base class for predictor models with a time covariate. @@ -357,44 +620,81 @@ class PredictorTime(Predictor): >>> y = predictor(x, time) - It is the responsibility of subclasses to define the behaviour of `_predict`. + It is the responsibility of subclasses to define the behaviour of `_mean`. Methods ------- - __call__(x: Union[array-like, pd.DataFrame]): - This makes predictions for an input `x`. The input data type can be either an array-like object - (like list or numpy array) or a pandas DataFrame. + __call__(x: Union[array-like, pd.DataFrame], normalize: bool = False): + + Equivalent to calling the `mean` method, this uses the trained model to make + predictions based on the input array 'Xnew', + considering the specified 'time' or 'multi_time'. + + The predictions represent the mean of the Gaussian Process conditional + distribution of predictive functions. + + If 'time' is a scalar, it will be converted into a 1D array of the same size as 'Xnew'. Parameters ---------- - x : array-like or pandas.DataFrame - The input data on which to make a prediction. + Xnew : array-like + The new data points for prediction. time : scalar or array-like, optional - The time points associated with each cell/row in 'x'. - If 'time' is a scalar, it will be converted into a 1D array of the same size as 'x'. + The time points associated with each row in 'Xnew'. + If 'time' is a scalar, it will be converted into a 1D array of the same size as 'Xnew'. + normalize : bool + Optional normalization by subtracting log(self.n_obs) + (number of cells trained on), applicable only for cell-state density predictions. + Default is False. Returns ------- - This method returns the predictions made by the model on the input data `x` and `time`. - The prediction is usually array-like with as many entries as `x.shape[0]`. + array + The predicted output generated by the model. 
+ + Raises + ------ + ValueError + If the number of features in 'x' does not align with the number of features the + predictor was trained on. + + Attributes + ---------- + n_obs : int + The average number of samples or cells per time point that the model was trained on. + This attribute is critical for normalization purposes, particularly when the `normalize` + parameter in the `__call__` method is set to `True`. + + n_input_features : int + The number of features/dimensions of the cell-state representation the predictor was + trained on. This is used for validation of input data. """ @make_multi_time_argument - def __call__(self, Xnew, time=None): + def mean(self, Xnew, time=None, normalize=False): """ - Call method to use the class instance as a function. This method - deals with an optional 'time' argument. - If 'time' is a scalar, it converts it to a 1D array of the same size as 'Xnew'. + Use the trained model to make predictions based on the input array 'Xnew', + considering the specified 'time' or 'multi_time'. + + The predictions represent the mean of the Gaussian Process conditional + distribution of predictive functions. + + If 'time' is a scalar, it will be converted into a 1D array of the same size as 'Xnew'. Parameters ---------- Xnew : array-like The new data points for prediction. time : scalar or array-like, optional - The time points associated with each cell/row in 'Xnew'. + The time points associated with each row in 'Xnew'. If 'time' is a scalar, it will be converted into a 1D array of the same size as 'Xnew'. + normalize : bool + Whether to normalize the value by subtracting log(self.n_obs) + (number of cells trained on). Applicable only for cell-state density predictions. + Default is False. multi_time : array-like, optional - If 'multi_time' is specified then a prediction will be made for each row. + If 'multi_time' is specified then a prediction for all states in x will + be made for each time value in multi_time separatly. Returns ------- @@ -410,8 +710,123 @@ def __call__(self, Xnew, time=None): Xnew = _validate_time_x( Xnew, time, n_features=self.n_input_features, cast_scalar=True ) + normalize = _validate_bool(normalize, "normalize") + + if normalize: + if self.n_obs is None or self.n_obs == 0: + message = ( + "Cannot normalize without n_obs. Please set self.n_obs to the number " + "of samples/cells (per time point) trained on to enable normalization." + ) + logger.error(message) + raise ValueError(message) + logger.warning( + 'The normalization is only effective if the density was trained with d_method="fractal".' + ) + return self._mean(Xnew) - log(self.n_obs) + else: + return self._mean(Xnew) + + __call__ = mean + + @make_multi_time_argument + def covariance(self, Xnew, time=None, diag=True): + """ + Computes the covariance of the Gaussian Process distribution of functions + over new data points or cell states. + + Parameters + ---------- + Xnew : array-like, shape (n_samples, n_features) + The new data points for which to compute the covariance. + time : scalar or array-like, optional + The time points associated with each cell/row in 'Xnew'. + If 'time' is a scalar, it will be converted into a 1D array of the same size as 'Xnew'. + diag : boolean, optional (default=True) + Whether to return the variance (True) or the full covariance matrix (False). + multi_time : array-like, optional + If 'multi_time' is specified then a covariance for all states in x will + be computed for each time value in multi_time separatly. 
+ + Returns + ------- + var : array-like, shape (n_samples,) + If diag=True, returns the variances for each sample. + cov : array-like, shape (n_samples, n_samples) + If diag=False, returns the full covariance matrix between samples. + """ + # if time is a scalar, convert it into a 1D array of the same size as Xnew + Xnew = _validate_time_x( + Xnew, time, n_features=self.n_input_features, cast_scalar=True + ) + return self._covariance(Xnew, diag=diag) + + @make_multi_time_argument + def mean_covariance(self, Xnew, time=None, diag=True): + """ + Computes the uncertainty of the mean of the Gaussian process induced by + the uncertainty of the latent representation of the mean function. + + Parameters + ---------- + Xnew : array-like, shape (n_samples, n_features) + The new data points for which to compute the uncertainty. + time : scalar or array-like, optional + The time points associated with each cell/row in 'Xnew'. + If 'time' is a scalar, it will be converted into a 1D array of the same size as 'Xnew'. + diag : boolean, optional (default=True) + Whether to compute the variance (True) or the full covariance matrix (False). + multi_time : array-like, optional + If 'multi_time' is specified then a mean covariance for all states in x will + be computed for each time value in multi_time separatly. + + Returns + ------- + var : array-like, shape (n_samples,) + If diag=True, returns the variances for each sample. + cov : array-like, shape (n_samples, n_samples) + If diag=False, returns the full covariance matrix between samples. + """ + # if time is a scalar, convert it into a 1D array of the same size as Xnew + Xnew = _validate_time_x( + Xnew, time, n_features=self.n_input_features, cast_scalar=True + ) + return self._mean_covariance(Xnew, diag=diag) - return self._predict(Xnew) + @make_multi_time_argument + def uncertainty(self, Xnew, time=None, diag=True): + """ + Computes the total uncertainty of the predicted values quantified by their variance + or covariance. + + The total uncertainty is defined by `.covariance` + `.mean_covariance`. + + Parameters + ---------- + Xnew : array-like, shape (n_samples, n_features) + The new data points for which to compute the uncertainty. + time : scalar or array-like, optional + The time points associated with each cell/row in 'Xnew'. + If 'time' is a scalar, it will be converted into a 1D array of the same size as 'Xnew'. + diag : bool, optional (default=True) + Whether to compute the variance (True) or the full covariance matrix (False). + multi_time : array-like, optional + If 'multi_time' is specified then a uncertainty for all states in x will + be computed for each time value in multi_time separatly. + + Returns + ------- + var : array-like, shape (n_samples,) if diag=True + The variances for each sample in the new data points. + cov : array-like, shape (n_samples, n_samples) if diag=False + The full covariance matrix between the samples in the new data points. + """ + Xnew = _validate_time_x( + Xnew, time, n_features=self.n_input_features, cast_scalar=True + ) + return self._covariance(Xnew, diag=diag) + self._mean_covariance( + Xnew, diag=diag + ) @make_multi_time_argument def time_derivative( @@ -437,10 +852,11 @@ def time_derivative( specific time point for all data points in `x`. If `time` is an array, it should be 1-D and the time derivative will be computed for all data-points at the corresponding time in the array. - multi_time : array-like, optional - If 'multi_time' is specified then a prediction will be made for each row. 
jit : bool, optional If True, use JAX's just-in-time (JIT) compilation to speed up the computation. Defaults to True. + multi_time : array-like, optional + If 'multi_time' is specified then a time derivative for all states in x will + be computed for each time value in multi_time separatly. Returns ------- @@ -465,10 +881,11 @@ def gradient(self, x, time, jit=True): Data points at which the gradient is to be computed. time : float Specific time point at which to compute the gradient. - multi_time : array-like, optional - If 'multi_time' is specified then the computation will be made for each row. jit : bool, optional If True, use JAX's just-in-time (JIT) compilation to speed up the computation. Defaults to True. + multi_time : array-like, optional + If 'multi_time' is specified then a gradient for all states in x will + be made for each time value in multi_time separatly. Returns ------- @@ -479,7 +896,7 @@ def gradient(self, x, time, jit=True): time = _validate_float(time, "time", optional=True) def dens_at(x): - return self.__call__(x, time) + return self.mean(x, time) return gradient(dens_at, x, jit=jit) @@ -498,6 +915,9 @@ def hessian(self, x, time, jit=True): If 'multi_time' is specified then the computation will be made for each row. jit : bool, optional If True, use JAX's just-in-time (JIT) compilation to speed up the computation. Defaults to True. + multi_time : array-like, optional + If 'multi_time' is specified then a hessian for all states in x will + be computed for each time value in multi_time separatly. Returns ------- @@ -524,10 +944,11 @@ def hessian_log_determinant(self, x, time, jit=True): Data points at which the log determinant is to be computed. time : float Specific time point at which to compute the log determinant. - multi_time : array-like, optional - If 'multi_time' is specified then the computation will be made for each row. jit : bool, optional If True, use JAX's just-in-time (JIT) compilation to speed up the computation. Defaults to True. + multi_time : array-like, optional + If 'multi_time' is specified then a log determinant for all states in x will + be computed for each time value in multi_time separatly. 
Returns ------- diff --git a/mellon/compute_ls_time.py b/mellon/compute_ls_time.py index 779f5ad..4ac4d1e 100644 --- a/mellon/compute_ls_time.py +++ b/mellon/compute_ls_time.py @@ -1,12 +1,12 @@ +import logging from jax.numpy import exp, unique, corrcoef, zeros, abs, stack from jax.numpy import sum as arraysum from jax.numpy.linalg import norm from jaxopt import ScipyMinimize from .density_estimator import DensityEstimator -from .util import Log from .validation import _validate_time_x -logger = Log() +logger = logging.getLogger("mellon") def compute_ls_time( diff --git a/mellon/conditional.py b/mellon/conditional.py index c283788..9840347 100644 --- a/mellon/conditional.py +++ b/mellon/conditional.py @@ -1,53 +1,129 @@ -from jax.numpy import dot, square, isnan, any +import logging +from jax.numpy import dot, square, isnan, any, eye +from jax.numpy import sum as arraysum +from jax.numpy import diag as diagonal from jax.numpy.linalg import cholesky from jax.scipy.linalg import solve_triangular -from .util import ensure_2d, stabilize, DEFAULT_JITTER, Log -from .base_predictor import Predictor, PredictorTime - - -logger = Log() - - -class _FullConditionalMean: +from .util import ensure_2d, stabilize, DEFAULT_JITTER, add_variance +from .base_predictor import Predictor, ExpPredictor, PredictorTime +from .decomposition import DEFAULT_SIGMA + + +logger = logging.getLogger("mellon") + + +def _get_L(x, cov_func, jitter=DEFAULT_JITTER, y_cov_factor=None): + K = cov_func(x, x) + K = add_variance(K, y_cov_factor, jitter=jitter) + L = cholesky(K) + if any(isnan(L)): + message = ( + f"Covariance not positively definite with jitter={jitter}. " + "Consider increasing the jitter for numerical stabilization." + ) + logger.error(message) + raise ValueError(message) + return L + + +def _check_covariance(obj): + if not hasattr(obj, "L"): + raise ValueError( + "The predictor was computed without covariance. " + "Recompute setting `with_uncertainty=True.`" + ) + + +def _check_uncertainty(obj): + if not hasattr(obj, "W"): + raise ValueError( + "The predictor was computed without uncertainty, e.g., using ADVI. " + "Recompute setting `with_uncertainty=True.` and define `pre_transformation_std`" + ", e.g., by using `optimizer='advi'`." + ) + + +def _sigma_to_y_cov_factor(sigma, y_cov_factor, n): + if sigma is None and y_cov_factor is None: + message = ( + "No input uncertainty specified. Make sure to set `sigma` or `pre_transformation_std`, " + 'e.g., by using `optimizer="advi", to quantify uncertainty of the prediction.' + ) + logger.error(message) + raise ValueError(message) + if y_cov_factor is not None and sigma is not None and any(sigma > 0): + raise ValueError( + "One can specify either `sigma` or `y_cov_factor` to describe input noise, but not both." + ) + + if y_cov_factor is None: + try: + y_cov_factor = diagonal(sigma) + except ValueError: + y_cov_factor = eye(n) * sigma + + return y_cov_factor + + +class _FullConditional: def __init__( self, x, y, mu, cov_func, - sigma=0, + L=None, + sigma=DEFAULT_SIGMA, jitter=DEFAULT_JITTER, + y_cov_factor=None, + y_is_mean=False, + with_uncertainty=False, ): - """ - The mean function of the conditioned Gaussian process. + R""" + The mean function of the conditioned Gaussian process (GP). :param x: The training instances. :type x: array-like :param y: The function value at each point in x. :type y: array-like - :param mu: The original Gaussian process mean. + :param mu: The original GP mean. :type mu: float - :param cov_func: The Gaussian process covariance function. 
+ :param cov_func: The GP covariance function. :type cov_func: function - :param sigma: White moise standard deviation. Defaults to 0. + + :param L : A matrix such that :math:`L L^\top \approx K`, where :math:`K` is the + covariance matrix of the GP. + :type L : array-like or None + :param sigma: Noise standard deviation of the data we condition on. Defaults to 0. :type sigma: float :param jitter: A small amount to add to the diagonal for stability. Defaults to 1e-6. :type jitter: float - :return: conditional_mean - The conditioned Gaussian process mean function. + :param y_cov_factor: A matrix :math:`\Sigma_L` such that + :math:`\Sigma_L\cdot\Sigma_L` is the covariance of `y`. + Only required if `with_uncertainty=True`. Defaults to None. + :type y_cov_factor: array-like + :param y_is_mean: Whether to consider y the GP mean or a noisy measurement + subject to `sigma` or `y_cov_factor`. Has no effect if `L` is passed. + Defaults to False. + :type y_is_mean: bool + :param with_uncertainty: Whether to compute covariance functions and + predictive uncertainty. Defaults to False. + :type with_uncertainty: bool + :return: conditional_mean - The conditioned GP mean function. :rtype: function """ x = ensure_2d(x) - sigma2 = square(sigma) - K = cov_func(x, x) - sigma2 = max(sigma2, jitter) - L = cholesky(stabilize(K, jitter=sigma2)) - if any(isnan(L)): - message = ( - f"Covariance not positively definite with jitter={jitter}. " - "Consider increasing the jitter for numerical stabilization." - ) - logger.error(message) - raise ValueError(message) + + if L is None: + logger.info("Recomputing covariance decomposition for predictive function.") + if y_is_mean: + logger.debug("Assuming y is the mean of the GP.") + L = _get_L(x, cov_func, jitter) + else: + logger.debug("Assuming y is not the mean of the GP.") + y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, x.shape[0]) + sigma = None + L = _get_L(x, cov_func, jitter, y_cov_factor) r = y - mu weights = solve_triangular(L.T, solve_triangular(L, r, lower=True)) @@ -55,16 +131,25 @@ def __init__( self.x = x self.weights = weights self.mu = mu + self.jitter = jitter self.n_input_features = x.shape[1] + self.n_obs = x.shape[0] + + self._state_variables = {"x", "weights", "mu", "jitter"} + + if not with_uncertainty: + return - def _data_dict(self): - return { - "x": self.x, - "weights": self.weights, - "mu": self.mu, - } + self.L = L + self._state_variables.add("L") - def _predict(self, Xnew): + y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, x.shape[0]) + + W = solve_triangular(L.T, solve_triangular(L, y_cov_factor, lower=True)) + self.W = W + self._state_variables.add("W") + + def _mean(self, Xnew): cov_func = self.cov_func x = self.x weights = self.weights @@ -73,16 +158,53 @@ def _predict(self, Xnew): Kus = cov_func(Xnew, x) return mu + dot(Kus, weights) + def _covariance(self, Xnew, diag=True): + _check_covariance(self) + x = self.x + cov_func = self.cov_func + L = self.L + + Kus = cov_func(x, Xnew) + A = solve_triangular(L, Kus, lower=True) + if diag: + Kss = cov_func.diag(Xnew) + var = Kss - arraysum(square(A), axis=0) + return var + else: + Kss = cov_func(Xnew, Xnew) + cov = Kss - dot(A.T, A) + return cov + + def _mean_covariance(self, Xnew, diag=True): + _check_uncertainty(self) + cov_func = self.cov_func + x = self.x + W = self.W + + Kus = cov_func(Xnew, x) + cov_L = Kus @ W + + if diag: + var = arraysum(cov_L * cov_L, axis=1) + return var + else: + cov = cov_L @ cov_L.T + return cov -class FullConditionalMean(_FullConditionalMean,
Predictor): + +class FullConditional(_FullConditional, Predictor): pass -class FullConditionalMeanTime(_FullConditionalMean, PredictorTime): +class ExpFullConditional(_FullConditional, ExpPredictor): pass -class _LandmarksConditionalMean: +class FullConditionalTime(_FullConditional, PredictorTime): + pass + + +class _LandmarksConditional: def __init__( self, x, @@ -90,10 +212,13 @@ def __init__( y, mu, cov_func, - sigma=0, + sigma=DEFAULT_SIGMA, jitter=DEFAULT_JITTER, + y_cov_factor=None, + y_is_mean=False, + with_uncertainty=False, ): - """ + R""" The mean function of the conditioned low rank gp, where rank is less than the number of landmark points. @@ -107,48 +232,71 @@ def __init__( :type mu: float :param cov_func: The Gaussian process covariance function. :type cov_func: function - :param sigma: White moise standard deviation. Defaults to 0. + :param sigma: Noise standard deviation of the data we condition on. Defaults to 0. :type sigma: float :param jitter: A small amount to add to the diagonal for stability. Defaults to 1e-6. :type jitter: float + :param y_cov_factor: A matrix :math:`\Sigma_L` such that + :math:`\Sigma_L\cdot\Sigma_L` is the covariance of `y`. + Only required if `with_uncertainty=True`. Defaults to None. + :type y_cov_factor: array-like + :param y_is_mean: Whether to consider y the GP mean or a noisy measurement + subject to `sigma` or `y_cov_factor`. Has no effect if `L` is passed. + Defaults to False. + :type y_is_mean: bool + :param with_uncertainty: Whether to compute predictive uncertainty and + intermediate covariance functions. Defaults to False. + :type with_uncertainty: bool :return: conditional_mean - The conditioned Gaussian process mean function. :rtype: function """ x = ensure_2d(x) xu = ensure_2d(xu) - sigma2 = square(sigma) - Kuu = cov_func(xu, xu) Kuf = cov_func(xu, x) - Luu = cholesky(stabilize(Kuu, jitter)) - if any(isnan(Luu)): - message = ( - f"Covariance of landmarks not positively definite with jitter={jitter}. " - "Consider increasing the jitter for numerical stabilization."
- ) - logger.error(message) - raise ValueError(message) - A = solve_triangular(Luu, Kuf, lower=True) - sigma2 = max(sigma2, jitter) - L_B = cholesky(stabilize(dot(A, A.T), sigma2)) + L = _get_L(xu, cov_func, jitter) + A = solve_triangular(L, Kuf, lower=True) + + LLB = dot(A, A.T) + if y_is_mean: + logger.debug("Assuming y is the mean of the GP.") + LLB = stabilize(LLB, jitter) + else: + logger.debug("Assuming y is not the mean of the GP.") + y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, xu.shape[0]) + sigma = None + LLB = add_variance(LLB, y_cov_factor, jitter=jitter) + + L_B = cholesky(LLB) r = y - mu c = solve_triangular(L_B, dot(A, r), lower=True) z = solve_triangular(L_B.T, c) - weights = solve_triangular(Luu.T, z) + weights = solve_triangular(L.T, z) self.cov_func = cov_func self.landmarks = xu self.weights = weights self.mu = mu + self.jitter = jitter self.n_input_features = xu.shape[1] + self.n_obs = x.shape[0] + + self._state_variables = {"landmarks", "weights", "mu", "jitter"} + + if not with_uncertainty: + return + + self.L = L + self._state_variables.add("L") + + y_cov_factor = _sigma_to_y_cov_factor(sigma, y_cov_factor, xu.shape[0]) - def _data_dict(self): - return { - "landmarks": self.landmarks, - "weights": self.weights, - "mu": self.mu, - } + C = solve_triangular(L_B, dot(A, y_cov_factor), lower=True) + Z = solve_triangular(L_B.T, C) + W = solve_triangular(L.T, Z) + self.W = W + self._state_variables.add("W") - def _predict(self, Xnew): + def _mean(self, Xnew): cov_func = self.cov_func xu = self.landmarks weights = self.weights @@ -157,24 +305,65 @@ def _predict(self, Xnew): Kus = cov_func(Xnew, xu) return mu + dot(Kus, weights) + def _covariance(self, Xnew, diag=False): + _check_covariance(self) + cov_func = self.cov_func + landmarks = self.landmarks + L = self.L + + K = cov_func(landmarks, Xnew) + A = solve_triangular(L, K, lower=True) + + if diag: + Kss = cov_func.diag(Xnew) + var = Kss - arraysum(square(A), axis=0) + return var + else: + cov = cov_func(Xnew, Xnew) - dot(A.T, A) + return cov + + def _mean_covariance(self, Xnew, diag=True): + _check_uncertainty(self) + cov_func = self.cov_func + xu = self.landmarks + W = self.W + + Kus = cov_func(Xnew, xu) + cov_L = Kus @ W + + if diag: + var = arraysum(cov_L * cov_L, axis=1) + return var + else: + cov = cov_L @ cov_L.T + return cov -class LandmarksConditionalMean(_LandmarksConditionalMean, Predictor): + +class LandmarksConditional(_LandmarksConditional, Predictor): pass -class LandmarksConditionalMeanTime(_LandmarksConditionalMean, PredictorTime): +class ExpLandmarksConditional(_LandmarksConditional, ExpPredictor): pass -class _LandmarksConditionalMeanCholesky: +class LandmarksConditionalTime(_LandmarksConditional, PredictorTime): + pass + + +class _LandmarksConditionalCholesky: def __init__( self, xu, pre_transformation, mu, cov_func, - sigma=0, + n_obs, + L=None, + sigma=DEFAULT_SIGMA, jitter=DEFAULT_JITTER, + y_is_mean=False, + with_uncertainty=False, ): """ The mean function of the conditioned low rank gp, where rank @@ -188,41 +377,66 @@ def __init__( :type mu: float :param cov_func: The Gaussian process covariance function. :type cov_func: function - :param sigma: White moise standard deviation. Defaults to 0. + :param n_obs: The number of observations/cells trained on. Used for normalization. + :type n_obs: int + :param L : A matrix such that :math:`L L^\top \approx K`, where :math:`K` is the + covariance matrix of the Gaussian Process. 
+ :type L : array-like or None + :param sigma: Standard deviation of `pre_transformation`. Defaults to 0. :type sigma: float :param jitter: A small amount to add to the diagonal for stability. Defaults to 1e-6. :type jitter: float + :param y_is_mean: Wether to consider y the GP mean or a noise measurment + subject to `sigma`. Has no effect if `L` is passed. + Defaults to False. + :type y_is_mean: bool + :param with_uncertainty: Wether to compute predictive uncertainty and + intermediate covariance functions. Defaults to False. + :type with_uncertainty: bool :return: conditional_mean - The conditioned Gaussian process mean function. :rtype: function """ xu = ensure_2d(xu) - sigma2 = square(sigma) - K = cov_func(xu, xu) - sigma2 = max(sigma2, jitter) - L = cholesky(stabilize(K, jitter=sigma2)) - if any(isnan(L)): - message = ( - f"Covariance not positively definite with jitter={jitter}. " - "Consider increasing the jitter for numerical stabilization." - ) - logger.error(message) - raise ValueError(message) + if L is None: + logger.info("Recomputing covariance decomposition for predictive function.") + if y_is_mean: + logger.debug("Assuming y is the mean of the GP.") + L = _get_L(xu, cov_func, jitter) + else: + logger.debug("Assuming y is not the mean of the GP.") + y_cov_factor = _sigma_to_y_cov_factor(sigma, None, xu.shape[0]) + sigma = None + L = _get_L(xu, cov_func, jitter, y_cov_factor) + weights = solve_triangular(L.T, pre_transformation) self.cov_func = cov_func self.landmarks = xu self.weights = weights self.mu = mu + self.jitter = jitter self.n_input_features = xu.shape[1] + self.n_obs = n_obs + + self._state_variables = {"landmarks", "weights", "mu", "jitter"} + + if not with_uncertainty: + return + + self.L = L + self._state_variables.add("L") + + try: + Stds = diagonal(sigma) + except ValueError: + # sigma seems to be scalar + Stds = eye(xu.shape[0]) * sigma - def _data_dict(self): - return { - "landmarks": self.landmarks, - "weights": self.weights, - "mu": self.mu, - } + W = solve_triangular(L.T, Stds) + self.W = W + self._state_variables.add("W") - def _predict(self, Xnew): + def _mean(self, Xnew): cov_func = self.cov_func xu = self.landmarks weights = self.weights @@ -231,12 +445,49 @@ def _predict(self, Xnew): Kus = cov_func(Xnew, xu) return mu + dot(Kus, weights) + def _covariance(self, Xnew, diag=True): + _check_covariance(self) + + cov_func = self.cov_func + landmarks = self.landmarks + L = self.L + + K = cov_func(landmarks, Xnew) + A = solve_triangular(L, K, lower=True) + + if diag: + Kss = cov_func.diag(Xnew) + var = Kss - arraysum(square(A), axis=0) + return var + else: + cov = cov_func(Xnew, Xnew) - dot(A.T, A) + return cov + + def _mean_covariance(self, Xnew, diag=True): + _check_uncertainty(self) + + cov_func = self.cov_func + xu = self.landmarks + W = self.W + + Kus = cov_func(Xnew, xu) + cov_L = Kus @ W + + if diag: + var = arraysum(cov_L * cov_L, axis=1) + return var + else: + cov = cov_L @ cov_L.T + return cov + + +class LandmarksConditionalCholesky(_LandmarksConditionalCholesky, Predictor): + pass + -class LandmarksConditionalMeanCholesky(_LandmarksConditionalMeanCholesky, Predictor): +class ExpLandmarksConditionalCholesky(_LandmarksConditionalCholesky, ExpPredictor): pass -class LandmarksConditionalMeanCholeskyTime( - _LandmarksConditionalMeanCholesky, PredictorTime -): +class LandmarksConditionalCholeskyTime(_LandmarksConditionalCholesky, PredictorTime): pass diff --git a/mellon/cov.py b/mellon/cov.py index 69df51d..fb117ba 100644 --- a/mellon/cov.py +++ 
b/mellon/cov.py @@ -259,7 +259,7 @@ class RatQuad(Covariance): Default is 1.0. alpha : float - The alpha parameter of the Rational Quadratic kernel. + The alpha parameter of the Rational Quadratic kernel. Default is 1.0. active_dims : array-like, slice or scalar, optional The indices of the active dimensions. If specified, the kernel function @@ -267,7 +267,7 @@ class RatQuad(Covariance): all dimensions are active. """ - def __init__(self, alpha, ls=1.0, active_dims=None): + def __init__(self, alpha=1.0, ls=1.0, active_dims=None): super().__init__() self.ls = ls self.alpha = alpha diff --git a/mellon/decomposition.py b/mellon/decomposition.py index 3f7923d..91c1baf 100644 --- a/mellon/decomposition.py +++ b/mellon/decomposition.py @@ -1,73 +1,50 @@ -from jax.numpy import cumsum, searchsorted, count_nonzero, sqrt, isnan, any +import logging +from jax.numpy import ( + cumsum, + searchsorted, + count_nonzero, + sqrt, + isnan, + any, + where, + square, +) from jax.numpy.linalg import eigh, cholesky, qr from jax.scipy.linalg import solve_triangular -from .util import stabilize, DEFAULT_JITTER, Log +from .util import stabilize, DEFAULT_JITTER -DEFAULT_RANK = 1.0 -DEFAULT_METHOD = "auto" +DEFAULT_RANK = 0.99 +DEFAULT_SIGMA = 0 -logger = Log() +logger = logging.getLogger("mellon") -def _check_method(rank, full, method): - R""" - Checks if rank is a float 0.0 :math:`\le` rank :math:`\le` 1.0 or an int - 1 :math:`\le` rank :math:`\le` full. Raises an error if neither is true - or if method doesn't match the detected method. - - :param rank: The rank of the decomposition, or if rank is a float greater - than 0 and less than 1, the rank is reduced further using the QR decomposition - such that the eigenvalues of the included eigenvectors account for the - specified percentage of the total eigenvalues. Defaults to 0.999. - :type rank: int or float - :param full: The size of the exact matrix. - :type full: int - :param method: The method to interpret the rank. - :type method: str - :return: method - The detected method. - :rtype: str - """ - - percent = isinstance(rank, float) and (0 < rank) and (rank <= 1) - fixed = isinstance(rank, int) and (1 <= rank) and (rank <= full) - if not (percent or fixed): - message = """rank must be a float 0.0 <=rank <= 1.0 or - an int 1 <= rank <= q. q equals the number of landmarks - or the number of data points if there are no landmarks.""" - raise ValueError(message) - elif percent and not (method == "percent" or method == "auto"): - message = f"""The argument method={method} does not match the rank={rank}. - The detected method from the rank is 'percent'.""" - raise ValueError(message) - elif fixed and not (method == "fixed" or method == "auto"): - message = f"""The argument method={method} does not match the rank={rank}. - The detected method from the rank is 'fixed'.""" - raise ValueError(message) - if percent: - return "percent" - else: - return "fixed" - - -def _eigendecomposition(A, rank=DEFAULT_RANK, method=DEFAULT_METHOD): +def _eigendecomposition(A, rank=DEFAULT_RANK): R""" Decompose :math:`A` into its largest positive `rank` and at least one eigenvector(s) and eigenvalue(s). - :param A: A square matrix. - :type A: array-like - :param rank: The rank of the decomposition, or if rank is a float + Parameters + ---------- + A : array-like + A square matrix. 
+ rank : int or float, optional + The rank of the decomposition, or if rank is a float 0.0 :math:`\le` rank :math:`\le` 1.0, the rank is reduced further using the QR decomposition such that the eigenvalues of the included eigenvectors account for - the specified percentage of the total eigenvalues. Defaults to 0.999. - :type rank: int or float - :param method: Explicitly specifies whether rank is to be interpreted as a - fixed number of eigenvectors or a percent of eigenvalues to include - in the low rank approximation. - :type method: str - :return: :math:`s, v` - The top eigenvalues and eigenvectors. - :rtype: array-like, array-like + the specified percentage of the total eigenvalues. Defaults to 0.99. + + Returns + ------- + array-like, array-like + :math:`s, v` - The top eigenvalues and eigenvectors. + + Notes + ----- + If any eigenvalues are less than or equal to 0, a warning message will be logged, + indicating a singularity in the covariance matrix. Consider raising the jitter + value to address this issue. """ s, v = eigh(A) @@ -79,7 +56,7 @@ def _eigendecomposition(A, rank=DEFAULT_RANK, method=DEFAULT_METHOD): logger.warning(message) p = count_nonzero(s > 0) # stability summed = cumsum(s[: -p - 1 : -1]) - if method == "percent": + if isinstance(rank, float): # automatically choose rank to capture some percent of the eigenvalues target = summed[-1] * rank p = searchsorted(summed, target) @@ -91,7 +68,7 @@ def _eigendecomposition(A, rank=DEFAULT_RANK, method=DEFAULT_METHOD): p = 1 else: p = min(rank, p) - if (method == "percent" and rank < 1) or rank < len(summed): + if (isinstance(rank, float) and rank < 1) or rank < len(summed): frac = summed[p] / summed[-1] logger.info(f"Recovering {frac:%} variance in eigendecomposition.") s_ = s[-p:] @@ -99,21 +76,42 @@ def _eigendecomposition(A, rank=DEFAULT_RANK, method=DEFAULT_METHOD): return s_, v_ -def _full_rank(x, cov_func, jitter=DEFAULT_JITTER): +def _full_rank(x, cov_func, sigma=DEFAULT_SIGMA, jitter=DEFAULT_JITTER): R""" Compute :math:`L` such that :math:`L L^\top = K`, where :math:`K` is the full rank covariance matrix. - :param x: The training instances. - :type x: array-like - :param cov_func: The Gaussian process covariance function. - :type cov_func: function - :param jitter: A small amount to add to the diagonal. Defaults to 1e-6. - :type jitter: float - :return: :math:`L` - A matrix such that :math:`L L^\top = K`. - :rtype: array-like + Parameters + ---------- + x : array-like + The training instances. + cov_func : function + The Gaussian process covariance function. + sigma : float, optional + Noise standard deviation of the data we condition on. Defaults to 0. + jitter : float, optional + A small amount to add to the diagonal. Defaults to 1e-6. + + Returns + ------- + array-like + :math:`L` - A matrix such that :math:`L L^\top = K`. + + Raises + ------ + ValueError + If the covariance is not positively definite even with jitter, this error will be raised. + Consider increasing the jitter for numerical stabilization. + + Notes + ----- + If any NaN values are detected in `L`, an error message is logged, and a ValueError is raised, + indicating that the covariance is not positively definite with the given jitter value. 
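As a standalone illustration of the `rank` handling in `_eigendecomposition` above (a plain NumPy toy, not the library code): a float rank is translated into the number of leading eigenvectors needed to cover that fraction of the eigenvalue mass, while an int rank is used as a fixed count.

```python
# Plain NumPy toy, not the library code: float rank -> fraction of eigenvalue
# mass to cover, int rank -> fixed number of eigenvectors to keep.
import numpy as np

eigvals = np.array([0.1, 0.4, 1.5, 3.0, 5.0])     # ascending, as returned by eigh
summed = np.cumsum(eigvals[::-1])                 # cumulative sum, largest first

rank = 0.9                                        # float: cover 90% of the mass
p = np.searchsorted(summed, summed[-1] * rank) + 1
print(p)                                          # -> 3 eigenvectors retained

rank = 2                                          # int: retain exactly 2 eigenvectors
print(min(rank, len(eigvals)))                    # -> 2
```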
""" - W = stabilize(cov_func(x, x), jitter) + sigma2 = square(sigma) + sigma2 = where(sigma2 < jitter, jitter, sigma2) + + W = stabilize(cov_func(x, x), sigma2) L = cholesky(W) if any(isnan(L)): message = ( @@ -126,107 +124,143 @@ def _full_rank(x, cov_func, jitter=DEFAULT_JITTER): def _full_decomposition_low_rank( - x, cov_func, rank=DEFAULT_RANK, method=DEFAULT_METHOD, jitter=DEFAULT_JITTER + x, + cov_func, + rank=DEFAULT_RANK, + sigma=DEFAULT_SIGMA, + jitter=DEFAULT_JITTER, ): R""" - Compute a low rank :math:`L` such that :math:`L L^\top ~= K`, where :math:`K` is the + Compute a low rank :math:`L` such that :math:`L L^\top \approx K`, where :math:`K` is the full rank covariance matrix. The rank is less than or equal to the number of landmark points. - :param x: The training instances. - :type x: array-like - :param cov_func: The Gaussian process covariance function. - :type cov_func: function - :param rank: The rank of the decomposition, or if rank is a float greater + Parameters + ---------- + x : array-like + The training instances. + cov_func : function + The Gaussian process covariance function. + rank : int or float, optional + The rank of the decomposition, or if rank is a float greater than 0 and less than 1, the eigenvalues of the included eigenvectors account for the specified percentage of the total eigenvalues. - Defaults to 0.999. - :type rank: int or float - :param jitter: A small amount to add to the diagonal. Defaults to 1e-6. - :type jitter: float - :param method: Explicitly specifies whether rank is to be interpreted as a - fixed number of eigenvectors or a percent of eigenvalues to include - in the low rank approximation. Supports 'fixed', 'percent', or 'auto'. - If 'auto', interprets rank as a fixed number of eigenvectors if it is - an int and interprets rank as a percent of eigenvalues if it is a float. - Defaults to 'auto'. - :type method: str - :return: :math:`L` - A matrix such that :math:`L L^\top \approx K`. - :rtype: array-like + Defaults to 0.99. + sigma : float, optional + Noise standard deviation of the data we condition on. Defaults to 0. + jitter : float, optional + A small amount to add to the diagonal. Defaults to 1e-6. + + Returns + ------- + array-like + :math:`L` - A matrix such that :math:`L L^\top \approx K`. + + Notes + ----- + The rank of the decomposition is determined by either the integer value provided or + automatically selected to capture the specified percentage of total eigenvalues if a float is + provided. This function computes the low-rank approximation of the full covariance matrix. """ - W = cov_func(x, x) - s, v = _eigendecomposition(W, rank=rank, method=method) + sigma2 = square(sigma) + sigma2 = where(sigma2 < jitter, jitter, sigma2) + + W = stabilize(cov_func(x, x), sigma2) + s, v = _eigendecomposition(W, rank=rank) L = v * sqrt(s) return L -def _standard_low_rank(x, cov_func, xu, jitter=DEFAULT_JITTER): +def _standard_low_rank( + x, cov_func, xu, Lp=None, sigma=DEFAULT_SIGMA, jitter=DEFAULT_JITTER +): R""" - Compute a low rank :math:`L` such that :math:`L L^\top \approx K`, where :math:`K` - is the full rank covariance matrix. The rank is equal to the number of - landmark points. + Compute a low rank :math:`L` such that :math:`L L^\top \approx K`, + where :math:`K` is the full rank covariance matrix on `x`, and + :math:`L_p L_p^\top = \Sigma_p` where :math:`\Sigma_p` is the full rank + covariance matrix on `xu`. The rank is equal to the number of landmark points. - :param x: The training instances. 
- :type x: array-like - :param cov_func: The Gaussian process covariance function. - :type cov_func: function - :param xu: The landmark points. - :type xu: array-like - :param jitter: A small amount to add to the diagonal. Defaults to 1e-6. - :type jitter: float - :return: :math:`L` - A matrix such that :math:`L L^\top \approx K`. - :rtype: array-like + Parameters + ---------- + x : array-like + The training instances. + cov_func : function + The Gaussian process covariance function. + xu : array-like + The landmark points. + Lp : array-like, optional + A matrix :math:`L_p L_p^\top = \Sigma_p` where :math:`\Sigma_p` is + the full rank covariance matrix on the landmarks `xu`. + Pass to avoid recomputing, by default None. + sigma : float, optional + Noise standard deviation of the data we condition on, by default 0. + jitter : float, optional + A small amount to add to the diagonal, by default 1e-6. + + Returns + ------- + array-like, array-like + :math:`L` - A matrix such that :math:`L L^\top \approx K`. """ - W = stabilize(cov_func(xu, xu), jitter) C = cov_func(x, xu) - U = cholesky(W) - if any(isnan(U)): - message = ( - f"Covariance of landmarks not positively definite with jitter={jitter}. " - "Consider increasing the jitter for numerical stabilization." - ) - logger.error(message) - raise ValueError(message) - L = solve_triangular(U, C.T, lower=True).T + + if Lp is None: + Lp = _full_rank(xu, cov_func, sigma=sigma, jitter=jitter) + L = solve_triangular(Lp, C.T, lower=True).T return L def _modified_low_rank( - x, cov_func, xu, rank=DEFAULT_RANK, method=DEFAULT_METHOD, jitter=DEFAULT_JITTER + x, + cov_func, + xu, + rank=DEFAULT_RANK, + sigma=DEFAULT_SIGMA, + jitter=DEFAULT_JITTER, ): R""" - Compute a low rank :math:`L` such that :math:`L L^\top ~= K`, where :math:`K` is the - full rank covariance matrix. The rank is less than or equal to the number of - landmark points. + Compute a low rank :math:`L` and :math:`L_p` such that :math:`L L^\top \approx K`, + where :math:`K` is the full rank covariance matrix on `x`. + The rank is equal to the number of landmark points. This is the improved + Nyström rank reduction method. - :param x: The training instances. - :type x: array-like - :param cov_func: The Gaussian process covariance function. - :type cov_func: function - :param xu: The landmark points. - :type xu: array-like - :param rank: The rank of the decomposition, or if rank is a float + Parameters + ---------- + x : array-like + The training instances. + cov_func : function + The Gaussian process covariance function. + xu : array-like + The landmark points. + rank : int or float, optional + The rank of the decomposition, or if rank is a float 0.0 :math:`\le` rank :math:`\le` 1.0, the rank is reduced further using the QR decomposition such that the eigenvalues of the included eigenvectors - account for the specified percentage of the total eigenvalues. Defaults to 0.999. - :type rank: int or float - :param jitter: A small amount to add to the diagonal. Defaults to 1e-6. - :type jitter: float - :param method: Explicitly specifies whether rank is to be interpreted as a - fixed number of eigenvectors or a percent of eigenvalues to include - in the low rank approximation. Supports 'fixed', 'percent', or 'auto'. - If 'auto', interprets rank as a fixed number of eigenvectors if it is - an int and interprets rank as a percent of eigenvalues if it is a float. - Defaults to 'auto'. - :type method: str - :return: :math:`L` - A matrix such that :math:`L L^\top \approx K`. 
- :rtype: array-like + account for the specified percentage of the total eigenvalues. Defaults to 0.99. + sigma : float, optional + Noise standard deviation of the data we condition on. Defaults to 0. + jitter : float, optional + A small amount to add to the diagonal. Defaults to 1e-6. + + Returns + ------- + array-like + :math:`L` - A matrix such that :math:`L L^\top \approx K`. + + Notes + ----- + This function computes a low-rank approximation of the full covariance matrix using + an improved Nyström method. The rank reduction is controlled either by an integer value or + a floating-point value that specifies the percentage of total eigenvalues. """ - W = stabilize(cov_func(xu, xu), jitter) + sigma2 = square(sigma) + sigma2 = where(sigma2 < jitter, jitter, sigma2) + + W = stabilize(cov_func(xu, xu), sigma2) C = cov_func(x, xu) Q, R = qr(C, mode="reduced") - s, v = _eigendecomposition(W, rank=xu.shape[0], method="fixed") + s, v = _eigendecomposition(W, rank=xu.shape[0]) T = R @ v - S, V = _eigendecomposition(T / s @ T.T, rank=rank, method=method) + S, V = _eigendecomposition(T / s @ T.T, rank=rank) L = Q @ V * sqrt(S) return L diff --git a/mellon/density_estimator.py b/mellon/density_estimator.py index ddf1adf..289a2bc 100644 --- a/mellon/density_estimator.py +++ b/mellon/density_estimator.py @@ -1,10 +1,10 @@ -from .decomposition import DEFAULT_RANK, DEFAULT_METHOD +import logging from .base_model import BaseEstimator, DEFAULT_COV_FUNC from .inference import ( compute_transform, compute_loss_func, compute_log_density_x, - compute_conditional_mean, + compute_conditional, DEFAULT_N_ITER, DEFAULT_INIT_LEARN_RATE, DEFAULT_JIT, @@ -15,11 +15,9 @@ compute_d_factal, compute_mu, compute_initial_value, - DEFAULT_N_LANDMARKS, ) from .util import ( DEFAULT_JITTER, - Log, ) from .validation import ( _validate_string, @@ -29,7 +27,7 @@ DEFAULT_D_METHOD = "embedding" -logger = Log() +logger = logging.getLogger("mellon") class DensityEstimator(BaseEstimator): @@ -48,24 +46,34 @@ class DensityEstimator(BaseEstimator): Defaults to Matern52. n_landmarks : int - The number of landmark points. If less than 1 or greater than or equal to the - number of training points, inducing points will not be computed or used. - Defaults to 5000. + The number of landmark/inducing points. Only used if a sparse GP is indicated + through gp_type. If 0 or equal to the number of training points, inducing points + will not be computed or used. Defaults to 5000. rank : int or float - The rank of the approximate covariance matrix. If rank is an int, an :math:`n \times` + The rank of the approximate covariance matrix for the Nyström rank reduction. + If rank is an int, an :math:`n \times` rank matrix :math:`L` is computed such that :math:`L L^\top \approx K`, where `K` is the exact :math:`n \times n` covariance matrix. If rank is a float 0.0 :math:`\le` rank :math:`\le` 1.0, the rank/size of :math:`L` is selected such that the included eigenvalues of the covariance between landmark points account for the specified percentage of the sum - of eigenvalues. Defaults to 0.99. - - method : str - Determines how the rank is interpreted: as a fixed number of eigenvectors ('fixed'), a - percent of eigenvalues ('percent'), or automatically ('auto'). If 'auto', the rank is - interpreted as a fixed number of eigenvectors if it is an int and as a percent of - eigenvalues if it is a float. This parameter is provided for clarity in the ambiguous case - of 1 vs 1.0. Defaults to 'auto'. + of eigenvalues. 
It is ignored if gp_type does not indicate a Nyström rank reduction. + Defaults to 0.99. + + gp_type : str or GaussianProcessType + The type of sparsification used for the Gaussian Process: + - 'full' Non-sparse Gaussian Process + - 'full_nystroem' Sparse GP with Nyström rank reduction without landmarks, + which lowers the computational complexity. + - 'sparse_cholesky' Sparse GP using landmarks/inducing points, + typically employed to enable scalable GP models. + - 'sparse_nystroem' Sparse GP using landmarks or inducing points, + along with an improved Nyström rank reduction method. + + The value can be either a string matching one of the above options or an instance of + the `mellon.util.GaussianProcessType` Enum. If a partial match is found with the + Enum, a warning will be logged, and the closest match will be used. + Defaults to 'sparse_cholesky'. d_method : str The method to compute the intrinsic dimensionality of the data. Implemented options are @@ -115,11 +123,21 @@ class DensityEstimator(BaseEstimator): the geometric mean of the nearest neighbor distances times a constant. If `cov_func` is supplied explicitly, `ls` has no effect. Defaults to None. - cov_func : function or None - The Gaussian process covariance function of the form k(x, y) :math:`\rightarrow` float. + ls_factor : float, optional + A scaling factor applied to the length scale when it's automatically + selected. It is used to manually adjust the automatically chosen length + scale for finer control over the model's sensitivity to variations in the data. + + cov_func : mellon.Covariance or None + The Gaussian process covariance function as an instance of :class:`mellon.Covariance`. If None, the covariance function `cov_func` is automatically generated as `cov_func_curry(ls)`. Defaults to None. + Lp : array-like or None + A matrix such that :math:`L_p L_p^\top = \Sigma_p`, where :math:`\Sigma_p` is the + covariance matrix of the inducing points (all cells in non-sparse GP). + Not used when Nyström rank reduction is employed. Defaults to None. + L : array-like or None A matrix such that :math:`L L^\top \approx K`, where :math:`K` is the covariance matrix. If None, `L` is computed automatically. Defaults to None. @@ -130,17 +148,39 @@ class DensityEstimator(BaseEstimator): - (d/2) \cdot \log(\pi) - d \cdot \log(\text{nn_distances})` and :math:`d` is the intrinsic dimensionality of the data. Defaults to None. + predictor_with_uncertainty : bool + If set to True, computes the predictor instance `.predict` with its predictive uncertainty. + The uncertainty comes from two sources: + + 1) `.predict.mean_covariance`: + Uncertainty arising from the posterior distribution of the Bayesian inference. + This component quantifies uncertainties inherent in the model's parameters and structure. + Available only if `.pre_transformation_std` is defined (e.g., using `optimizer="advi"`), + which reflects the standard deviation of the latent variables before transformation. + + 2) `.predict.covariance`: + Uncertainty for out-of-bag states originating from the compressed function representation + in the Gaussian Process. Specifically, this uncertainty corresponds to locations that are + not inducing points of the Gaussian Process and represents the covariance of the + conditional normal distribution. + jit : bool Use jax just-in-time compilation for loss and its gradient during optimization. Defaults to False.
+ + check_rank : bool + Weather to check if landmarks allow sufficient complexity by checking the approximate + rank of the covariance matrix. This only applies to the non-Nyström gp_types. + If set to None the rank check is only performed if n_landmarks >= n_samples/10. + Defaults to None. """ def __init__( self, cov_func_curry=DEFAULT_COV_FUNC, - n_landmarks=DEFAULT_N_LANDMARKS, - rank=DEFAULT_RANK, - method=DEFAULT_METHOD, + n_landmarks=None, + rank=None, + gp_type=None, d_method=DEFAULT_D_METHOD, jitter=DEFAULT_JITTER, optimizer=DEFAULT_OPTIMIZER, @@ -153,15 +193,19 @@ def __init__( ls=None, ls_factor=1, cov_func=None, + Lp=None, L=None, initial_value=None, + predictor_with_uncertainty=False, jit=DEFAULT_JIT, + check_rank=None, ): super().__init__( cov_func_curry=cov_func_curry, n_landmarks=n_landmarks, rank=rank, jitter=jitter, + gp_type=gp_type, optimizer=optimizer, n_iter=n_iter, init_learn_rate=init_learn_rate, @@ -172,9 +216,12 @@ def __init__( ls=ls, ls_factor=ls_factor, cov_func=cov_func, + Lp=Lp, L=L, initial_value=initial_value, + predictor_with_uncertainty=predictor_with_uncertainty, jit=jit, + check_rank=check_rank, ) self.d_method = _validate_string( d_method, "d_method", choices={"fractal", "embedding"} @@ -184,52 +231,21 @@ def __init__( self.opt_state = None self.losses = None self.pre_transformation = None + self.pre_transformation_std = None self.log_density_x = None self.log_density_func = None - def __repr__(self): - name = self.__class__.__name__ - string = ( - f"{name}(" - f"cov_func_curry={self.cov_func_curry}, " - f"n_landmarks={self.n_landmarks}, " - f"rank={self.rank}, " - f"method='{self.method}', " - f"jitter={self.jitter}, " - f"optimizer='{self.optimizer}', " - f"n_iter={self.n_iter}, " - f"init_learn_rate={self.init_learn_rate}, " - f"landmarks={self.landmarks}, " - ) - if self.nn_distances is None: - string += "nn_distances=None, " - else: - string += "nn_distances=nn_distances, " - string += ( - f"d={self.d}, " - f"mu={self.mu}, " - f"ls={self.ls}, " - f"cov_func={self.cov_func}, " - ) - if self.L is None: - string += "L=None, " - else: - string += "L=L, " - if self.initial_value is None: - string += "initial_value=None, " - else: - string += "initial_value=initial_value, " - string += f"jit={self.jit}" ")" - return string - def _compute_d(self): x = self.x if self.d_method == "fractal": - logger.warning("Using EXPERIMENTAL fractal dimensionality selection.") d = compute_d_factal(x) logger.info(f"Using d={d}.") else: d = compute_d(x) + logger.info( + f"Using embedding dimensionality d={d}. " + 'Use d_method="fractal" to enable effective density normalization.' + ) if d > 50: message = f"""The detected dimensionality of the data is over 50, which is likely to cause numerical instability issues. 
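A brief usage sketch of the explicit `gp_type` and `check_rank` parameters described in the docstring above; it assumes the constructor arguments shown in this diff and the `GaussianProcessType` Enum exposed by `mellon.util`, as imported elsewhere in the patch.

```python
# Hedged sketch: gp_type accepts either a string or a GaussianProcessType member.
import mellon
from mellon.util import GaussianProcessType  # as imported in function_estimator.py

est_str = mellon.DensityEstimator(gp_type="sparse_cholesky", check_rank=False)
est_enum = mellon.DensityEstimator(gp_type=GaussianProcessType.SPARSE_NYSTROEM, rank=0.95)
```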
@@ -277,19 +293,29 @@ def _set_log_density_func(self): x = self.x landmarks = self.landmarks pre_transformation = self.pre_transformation + pre_transformation_std = self.pre_transformation_std log_density_x = self.log_density_x mu = self.mu cov_func = self.cov_func + L = self.L + Lp = self.Lp jitter = self.jitter + with_uncertainty = self.predictor_with_uncertainty logger.info("Computing predictive function.") - log_density_func = compute_conditional_mean( + log_density_func = compute_conditional( x, landmarks, pre_transformation, + pre_transformation_std, log_density_x, mu, cov_func, + L, + Lp, + sigma=None, jitter=jitter, + y_is_mean=True, + with_uncertainty=with_uncertainty, ) self.log_density_func = log_density_func @@ -318,12 +344,17 @@ def prepare_inference(self, x): raise ValueError(message) x = self.set_x(x) + self._prepare_attribute("n_landmarks") + self._prepare_attribute("rank") + self._prepare_attribute("gp_type") + self._validate_parameter() self._prepare_attribute("nn_distances") self._prepare_attribute("d") self._prepare_attribute("mu") self._prepare_attribute("ls") self._prepare_attribute("cov_func") self._prepare_attribute("landmarks") + self._prepare_attribute("Lp") self._prepare_attribute("L") self._prepare_attribute("initial_value") self._prepare_attribute("transform") @@ -449,7 +480,7 @@ def fit_predict(self, x=None, build_predict=False): array-like The log density at each training point in `x`. """ - if self.x is not None and self.x is not x: + if self.x is not None and x is not None and self.x is not x: message = "self.x has been set already, but is not equal to the argument x." error = ValueError(message) logger.error(error) diff --git a/mellon/dimensionality_estimator.py b/mellon/dimensionality_estimator.py index ec4f71d..3756c20 100644 --- a/mellon/dimensionality_estimator.py +++ b/mellon/dimensionality_estimator.py @@ -1,11 +1,11 @@ -from .decomposition import DEFAULT_RANK, DEFAULT_METHOD +import logging from .base_model import BaseEstimator, DEFAULT_COV_FUNC from .inference import ( compute_dimensionality_transform, compute_dimensionality_loss_func, compute_log_density_x, - compute_conditional_mean, - compute_conditional_mean_explog, + compute_conditional, + compute_conditional_explog, DEFAULT_N_ITER, DEFAULT_INIT_LEARN_RATE, DEFAULT_JIT, @@ -15,12 +15,11 @@ compute_distances, compute_mu, compute_initial_dimensionalities, - DEFAULT_N_LANDMARKS, ) from .util import ( DEFAULT_JITTER, - Log, local_dimensionality, + object_str, ) from .validation import ( _validate_positive_int, @@ -29,7 +28,7 @@ ) -logger = Log() +logger = logging.getLogger("mellon") class DimensionalityEstimator(BaseEstimator): @@ -47,25 +46,33 @@ class DimensionalityEstimator(BaseEstimator): a curry function taking one length scale argument and returning a covariance function of the form k(x, y) :math:`\rightarrow` float. - n_landmarks: int, optional (default=5000) - The number of landmark points. If less than 1 or greater than or equal - to the number of training points, - inducing points are not computed or used. + n_landmarks : int, optional (default=5000) + The number of landmark/inducing points. Only used if a sparse GP is indicated + through gp_type. If 0 or equal to the number of training points, inducing points + will not be computed or used. rank: int or float, optional (default=0.99) - The rank of the approximate covariance matrix. 
When interpreted as an - integer, an :math:`n \times` rank matrix - :math:`L` is computed such that :math:`L L^\top \approx K`, where - :math:`K` is the exact :math:`n \times n` covariance matrix. - When interpreted as a float (between 0.0 and 1.0), the rank/size of - :math:`L` is chosen such that the included eigenvalues of the covariance - between landmark points account for the specified percentage of the total eigenvalues. - - method: str, optional (default='auto') - Determines whether the `rank` parameter is interpreted as a fixed - number of eigenvectors ('fixed'), a percentage of eigenvalues ('percent'), - or determined automatically ('auto'). In 'auto' mode, `rank` is treated - as a fixed number if it is an integer, or a percentage if it's a float. + The rank of the approximate covariance matrix for the Nyström rank reduction. + If rank is an int, an :math:`n \times` + rank matrix :math:`L` is computed such that :math:`L L^\top \approx K`, where `K` is the + exact :math:`n \times n` covariance matrix. If rank is a float 0.0 :math:`\le` rank + :math:`\le` 1.0, the rank/size of :math:`L` is selected such that the included eigenvalues + of the covariance between landmark points account for the specified percentage of the sum + of eigenvalues. It is ignored if gp_type does not indicate a Nyström rank reduction. + + gp_type : str or GaussianProcessType, optional (default='sparse_cholesky') + The type of sparcification used for the Gaussian Process: + - 'full' None-sparse Gaussian Process + - 'full_nystroem' Sparse GP with Nyström rank reduction without landmarks, + which lowers the computational complexity. + - 'sparse_cholesky' Sparse GP using landmarks/inducing points, + typically employed to enable scalable GP models. + - 'sparse_nystroem' Sparse GP using landmarks or inducing points, + along with an improved Nyström rank reduction method. + + The value can be either a string matching one of the above options or an instance of + the `mellon.util.GaussianProcessType` Enum. If a partial match is found with the + Enum, a warning will be logged, and the closest match will be used. jitter: float, optional (default=1e-6) A small amount added to the diagonal of the covariance matrix to ensure @@ -118,10 +125,15 @@ class DimensionalityEstimator(BaseEstimator): selected. It is used to manually adjust the automatically chosen length scale for finer control over the model's sensitivity to variations in the data. - cov_func: function or None, optional - The Gaussian process covariance function of the form k(x, y) - :math:`\rightarrow` float. If None, the covariance function is generated - automatically as `cov_func = cov_func_curry(ls)`. + cov_func : mellon.Covaraince or None + The Gaussian process covariance function as instance of :class:`mellon.Covaraince`. + If None, the covariance function `cov_func` is automatically generated as `cov_func_curry(ls)`. + Defaults to None. + + Lp : array-like or None + A matrix such that :math:`L_p L_p^\top = \Sigma_p`, where :math:`\Sigma_p` is the + covariance matrix of the inducing points (all cells in non-sparse GP). + Not used when Nyström rank reduction is employed. Defaults to None. L: array-like or None, optional A matrix such that :math:`L L^\top \approx K`, where :math:`K` is the @@ -134,16 +146,38 @@ class DimensionalityEstimator(BaseEstimator): for density initialization and the neighborhood-based local intrinsic dimensionality for dimensionality initialization. 
+ predictor_with_uncertainty : bool + If set to True, computes the predictor instances `.predict` and `.predict_density` + with its predictive uncertainty. The uncertainty comes from two sources: + + 1) `.predict.mean_covariance`: + Uncertainty arising from the posterior distribution of the Bayesian inference. + This component quantifies uncertainties inherent in the model's parameters and structure. + Available only if `.pre_transformation_std` is defined (e.g., using `optimizer="advi"`), + which reflects the standard deviation of the latent variables before transformation. + + 2) `.predict.covariance`: + Uncertainty for out-of-bag states originating from the compressed function representation + in the Gaussian Process. Specifically, this uncertainty corresponds to locations that are + not inducing points of the Gaussian Process and represents the covariance of the + conditional normal distribution. + jit: bool, optional (default=False) If True, use JAX's just-in-time compilation for loss and its gradient during optimization. + + check_rank : bool + Weather to check if landmarks allow sufficient complexity by checking the approximate + rank of the covariance matrix. This only applies to the non-Nyström gp_types. + If set to None the rank check is only performed if n_landmarks >= n_samples/10. + Defaults to None. """ def __init__( self, cov_func_curry=DEFAULT_COV_FUNC, - n_landmarks=DEFAULT_N_LANDMARKS, - rank=DEFAULT_RANK, - method=DEFAULT_METHOD, + n_landmarks=None, + rank=None, + gp_type=None, jitter=DEFAULT_JITTER, optimizer=DEFAULT_OPTIMIZER, n_iter=DEFAULT_N_ITER, @@ -157,15 +191,18 @@ def __init__( ls=None, ls_factor=1, cov_func=None, + Lp=None, L=None, initial_value=None, + predictor_with_uncertainty=False, jit=DEFAULT_JIT, + check_rank=None, ): super().__init__( cov_func_curry=cov_func_curry, n_landmarks=n_landmarks, rank=rank, - method=method, + gp_type=gp_type, jitter=jitter, optimizer=optimizer, n_iter=n_iter, @@ -177,9 +214,12 @@ def __init__( ls=ls, ls_factor=ls_factor, cov_func=cov_func, + Lp=Lp, L=L, initial_value=initial_value, + predictor_with_uncertainty=predictor_with_uncertainty, jit=jit, + check_rank=check_rank, ) self.k = _validate_positive_int(k, "k") self.mu_dim = _validate_float(mu_dim, "mu_dim") @@ -190,6 +230,7 @@ def __init__( self.opt_state = None self.losses = None self.pre_transformation = None + self.pre_transformation_std = None self.local_dim_x = None self.log_density_x = None self.local_dim_func = None @@ -197,38 +238,36 @@ def __init__( def __repr__(self): name = self.__class__.__name__ + landmarks = object_str(self.landmarks, ["landmarks", "dims"]) + Lp = object_str(self.Lp, ["landmarks", "landmarks"]) + L = object_str(self.L, ["cells", "ranks"]) + nn_distances = object_str(self.nn_distances, ["cells"]) + d = object_str(self.d, ["cells"]) + initial_value = object_str(self.initial_value, ["functions", "ranks"]) string = ( f"{name}(" - f"cov_func_curry={self.cov_func_curry}, " - f"n_landmarks={self.n_landmarks}, " - f"rank={self.rank}, " - f"method='{self.method}', " - f"jitter={self.jitter}, " - f"optimizer='{self.optimizer}', " - f"n_iter={self.n_iter}, " - f"init_learn_rate={self.init_learn_rate}, " - f"landmarks={self.landmarks}, " + f"\n cov_func_curry={self.cov_func_curry}," + f"\n n_landmarks={self.n_landmarks}," + f"\n rank={self.rank}," + f"\n gp_type={self.gp_type}," + f"\n jitter={self.jitter}, " + f"\n optimizer={self.optimizer}," + f"\n landmarks={landmarks}," + f"\n nn_distances={nn_distances}," + f"\n d={d}," + f"\n mu_dim={self.mu_dim}," 
+ f"\n mu_dens={self.mu_dens}," + f"\n ls={self.ls}," + f"\n ls_factor={self.ls_factor}," + f"\n cov_func={self.cov_func}," + f"\n Lp={Lp}," + f"\n L={L}," + f"\n initial_value={initial_value}," + f"\n predictor_with_uncertainty={self.predictor_with_uncertainty}," + f"\n jit={self.jit}," + f"\n check_rank={self.check_rank}," + "\n)" ) - if self.distances is None: - string += "distances=None, " - else: - string += "distances=distances, " - string += ( - f"d={self.d}, " - f"mu_dim={self.mu_dim}, " - f"mu_dens={self.mu_dens}, " - f"ls={self.ls}, " - f"cov_func={self.cov_func}, " - ) - if self.L is None: - string += "L=None, " - else: - string += "L=L, " - if self.initial_value is None: - string += "initial_value=None, " - else: - string += "initial_value=initial_value, " - string += f"jit={self.jit}" ")" return string def _compute_mu_dens(self): @@ -291,38 +330,64 @@ def _set_local_dim_x(self): def _set_local_dim_func(self): x = self.x landmarks = self.landmarks + pre_transformation = self.pre_transformation[0, :] + pre_transformation_std = self.pre_transformation_std + if pre_transformation_std is not None: + pre_transformation_std = pre_transformation_std[0, :] local_dim_x = self.local_dim_x mu = self.mu_dim cov_func = self.cov_func + L = self.L + Lp = self.Lp jitter = self.jitter + with_uncertainty = self.predictor_with_uncertainty logger.info("Computing predictive dimensionality function.") - log_dim_func = compute_conditional_mean_explog( + log_dim_func = compute_conditional_explog( x, landmarks, + pre_transformation, + pre_transformation_std, local_dim_x, mu, cov_func, + L, + Lp, + sigma=None, jitter=jitter, + y_is_mean=True, + with_uncertainty=with_uncertainty, ) self.local_dim_func = log_dim_func def _set_log_density_func(self): x = self.x landmarks = self.landmarks - pre_transformation = self.pre_transformation + pre_transformation = self.pre_transformation[1, :] + pre_transformation_std = self.pre_transformation_std + if pre_transformation_std is not None: + pre_transformation_std = pre_transformation_std[1, :] log_density_x = self.log_density_x mu = self.mu_dens cov_func = self.cov_func + L = self.L + Lp = self.Lp jitter = self.jitter + with_uncertainty = self.predictor_with_uncertainty logger.info("Computing predictive density function.") - log_density_func = compute_conditional_mean( + log_density_func = compute_conditional( x, landmarks, pre_transformation, + pre_transformation_std, log_density_x, mu, cov_func, + L, + Lp, + sigma=None, jitter=jitter, + y_is_mean=True, + with_uncertainty=with_uncertainty, ) self.log_density_func = log_density_func @@ -351,6 +416,10 @@ def prepare_inference(self, x): raise ValueError(message) x = self.set_x(x) + self._prepare_attribute("n_landmarks") + self._prepare_attribute("rank") + self._prepare_attribute("gp_type") + self._validate_parameter() self._prepare_attribute("distances") self._prepare_attribute("nn_distances") self._prepare_attribute("d") @@ -358,6 +427,7 @@ def prepare_inference(self, x): self._prepare_attribute("ls") self._prepare_attribute("cov_func") self._prepare_attribute("landmarks") + self._prepare_attribute("Lp") self._prepare_attribute("L") self._prepare_attribute("initial_value") self._prepare_attribute("transform") @@ -439,18 +509,16 @@ def fit(self, x=None, build_predict=True): @property def predict_density(self): """ - Predicts the log density with an adaptive unit for each data point in `x`. - The unit of density depends on the dimensionality of the data. 
+ A property that returns an instance of the :class:`mellon.Predictor` class. This predictor can + be used to predict the log density for new data points by calling the instance like a function. - Parameters - ---------- - x : array-like of shape (n_samples, n_features) - The new data for which to predict the log density. + The predictor instance also supports serialization features, which allow for saving and loading + the predictor's state. For more details, refer to the :class:`mellon.Predictor` documentation. Returns ------- - array-like - The predicted log density for each test point in `x`. + mellon.Predictor + A predictor instance that computes the log density at each new data point. Example ------- @@ -465,22 +533,18 @@ def predict_density(self): @property def predict(self): """ - Returns an instance of the :class:`mellon.Predictor` class, which predicts the dimensionality - at each point in `x`. - - This instance includes a __call__ method, which can be used to predict the dimensionality. - The instance also supports serialization features, allowing for saving and loading the predictor's - state. For more details, refer to :class:`mellon.Predictor`. + A property that returns an instance of the :class:`mellon.base_predictor.ExpPredictor` class. + This predictor can be used to predict the dimensionality for new data points by calling + the instance like a function. - Parameters - ---------- - x : array-like of shape (n_samples, n_features) - The new data for which to predict the dimensionality. + The predictor instance also supports serialization features, which allow for saving and loading + the predictor's state. For more details, refer to the :class:`mellon.base_predictor.ExpPredictor` + documentation. Returns ------- - array-like - The predicted dimensionality for each test point in `x`. + mellon.base_predictor.ExpPredictor + A predictor instance that computes the dimensionality at each new data point. Example ------- @@ -519,7 +583,7 @@ def fit_predict(self, x=None, build_predict=False): ValueError If the argument `x` does not match `self.x` which was already set in a previous operation. """ - if self.x is not None and self.x is not x: + if self.x is not None and x is not None and self.x is not x: message = "self.x has been set already, but is not equal to the argument x." error = ValueError(message) logger.error(error) diff --git a/mellon/function_estimator.py b/mellon/function_estimator.py index 24aa61a..20ee6dd 100644 --- a/mellon/function_estimator.py +++ b/mellon/function_estimator.py @@ -1,28 +1,26 @@ -from .decomposition import DEFAULT_METHOD +import logging from .base_model import BaseEstimator, DEFAULT_COV_FUNC from .inference import ( - compute_conditional_mean, + compute_conditional, DEFAULT_N_ITER, DEFAULT_INIT_LEARN_RATE, DEFAULT_OPTIMIZER, ) -from .parameters import ( - DEFAULT_N_LANDMARKS, -) from .util import ( DEFAULT_JITTER, - Log, + GaussianProcessType, ) from .validation import ( - _validate_positive_float, + _validate_float_or_iterable_numerical, _validate_float, _validate_array, + _validate_bool, ) DEFAULT_D_METHOD = "embedding" -logger = Log() +logger = logging.getLogger("mellon") class FunctionEstimator(BaseEstimator): @@ -36,9 +34,21 @@ class FunctionEstimator(BaseEstimator): A curry that takes one length scale argument and returns a covariance function of the form k(x, y) :math:`\rightarrow` float. Defaults to Matern52. - n_landmarks : int, optional - The number of landmark points. 
If less than 1 or greater than or equal to the - number of training points, inducing points will not be computed or used. Defaults to 5000. + n_landmarks : int + The number of landmark/inducing points. Only used if a sparse GP is indicated + through gp_type. If 0 or equal to the number of training points, inducing points + will not be computed or used. Defaults to 5000. + + gp_type : str or GaussianProcessType + The type of sparsification used for the Gaussian Process: + - 'full' Non-sparse Gaussian Process + - 'sparse_cholesky' Sparse GP using landmarks/inducing points, + typically employed to enable scalable GP models. + + The value can be either a string matching one of the above options or an instance of + the `mellon.util.GaussianProcessType` Enum. If a partial match is found with the + Enum, a warning will be logged, and the closest match will be used. + Defaults to 'sparse_cholesky'. jitter : float, optional A small amount added to the diagonal of the covariance matrix to ensure numerical stability. @@ -68,14 +78,32 @@ class FunctionEstimator(BaseEstimator): selected. It is used to manually adjust the automatically chosen length scale for finer control over the model's sensitivity to variations in the data. - cov_func : function or None, optional - The Gaussian process covariance function of the form k(x, y) :math:`\rightarrow` float. - If None, automatically generates the covariance function cov_func = cov_func_curry(ls). + cov_func : mellon.Covariance or None + The Gaussian process covariance function as an instance of :class:`mellon.Covariance`. + If None, the covariance function `cov_func` is automatically generated as `cov_func_curry(ls)`. Defaults to None. sigma : float, optional The standard deviation of the white noise. Defaults to 0. + y_is_mean : bool + Whether to consider y the GP mean or a noisy measurement + subject to `sigma` or `y_cov_factor`. Has no effect if `L` is passed. + Defaults to False. + + predictor_with_uncertainty : bool + If set to True, computes the predictor instance `.predict` with its predictive uncertainty. + The uncertainty comes from two sources: + + 1) `.predict.mean_covariance`: + Uncertainty arising from the input noise `sigma`. + + 2) `.predict.covariance`: + Uncertainty for out-of-bag states originating from the compressed function representation + in the Gaussian Process. Specifically, this uncertainty corresponds to locations that are + not inducing points of the Gaussian Process and represents the covariance of the + conditional normal distribution. + jit : bool, optional Use JAX just-in-time compilation for the loss function and its gradient during optimization. Defaults to False. 
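Editor's note: the parameter list above is easier to follow with a concrete call. The following is a minimal, untested sketch of how the new `FunctionEstimator` arguments are meant to be combined. The toy data, the `sigma` value, and all variable names are invented for illustration; only the constructor arguments, the `__call__`/`fit_predict` behavior, and the `.predict.covariance` / `.predict.mean_covariance` accessors are taken from the docstrings in this diff.

```python
import numpy as np
import mellon

# Invented toy data: 500 points in 2D with a noisy scalar response.
x = np.random.rand(500, 2)
y = np.sin(x[:, 0]) + 0.1 * np.random.randn(500)

est = mellon.FunctionEstimator(
    gp_type="sparse_cholesky",        # explicit sparsification strategy
    sigma=0.1,                        # standard deviation of the input noise
    y_is_mean=False,                  # treat y as a noisy measurement, not the GP mean
    predictor_with_uncertainty=True,  # attach covariance methods to the predictor
)
y_smooth = est(x, y)                  # __call__ delegates to fit_predict(x, y)

predictor = est.predict               # fitted predictor instance
x_new = np.random.rand(10, 2)
f_new = predictor(x_new)                     # conditional mean at new locations
cov = predictor.covariance(x_new)            # GP covariance for out-of-sample states
mean_cov = predictor.mean_covariance(x_new)  # uncertainty propagated from sigma
```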
@@ -84,8 +112,8 @@ class FunctionEstimator(BaseEstimator): def __init__( self, cov_func_curry=DEFAULT_COV_FUNC, - n_landmarks=DEFAULT_N_LANDMARKS, - method=DEFAULT_METHOD, + n_landmarks=None, + gp_type=None, jitter=DEFAULT_JITTER, optimizer=DEFAULT_OPTIMIZER, n_iter=DEFAULT_N_ITER, @@ -97,6 +125,8 @@ def __init__( ls_factor=1, cov_func=None, sigma=0, + y_is_mean=False, + predictor_with_uncertainty=False, jit=True, ): super().__init__( @@ -104,16 +134,32 @@ def __init__( n_landmarks=n_landmarks, rank=1.0, jitter=jitter, + gp_type=gp_type, landmarks=landmarks, nn_distances=nn_distances, mu=mu, ls=ls, ls_factor=ls_factor, cov_func=cov_func, + predictor_with_uncertainty=predictor_with_uncertainty, jit=jit, ) + self.y_is_mean = _validate_bool(y_is_mean, "y_is_mean") self.mu = _validate_float(mu, "mu") - self.sigma = _validate_positive_float(sigma, "sigma") + self.sigma = _validate_float_or_iterable_numerical( + sigma, "sigma", positive=True + ) + if ( + self.gp_type == GaussianProcessType.FULL_NYSTROEM + or self.gp_type == GaussianProcessType.SPARSE_NYSTROEM + ): + message = ( + f"gp_type={gp_type} but the Nyström rank reduction is " + "not available for the Function Estimator. " + "Use gp_type='cholesky' or gp_type='full' instead." + ) + logger.error(message) + raise ValueError(message) def __call__(self, x=None, y=None): """This calls self.fit_predict(x, y): @@ -143,6 +189,8 @@ def prepare_inference(self, x): :rtype: function, array-like """ x = self.set_x(x) + self._prepare_attribute("n_landmarks") + self._prepare_attribute("gp_type") if self.ls is None: self._prepare_attribute("nn_distances") self._prepare_attribute("ls") @@ -182,15 +230,22 @@ def compute_conditional(self, x=None, y=None): cov_func = self.cov_func sigma = self.sigma jitter = self.jitter - conditional = compute_conditional_mean( + y_is_mean = self.y_is_mean + with_uncertainty = self.predictor_with_uncertainty + conditional = compute_conditional( x, landmarks, None, + None, y, mu, cov_func, + None, + None, sigma, jitter=jitter, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) self.conditional = conditional return conditional @@ -227,8 +282,8 @@ def fit(self, x=None, y=None): # Check if the number of samples in x and y match if y.shape[0] != n_samples: raise ValueError( - f"X.shape[0] = {n_samples} (n_samples) should equal " - "y.shape[0] = {y.shape[0]}." + f"X.shape[0] = {n_samples:,} (n_samples) should equal " + f"y.shape[0] = {y.shape[0]:,}." 
) self.prepare_inference(x) diff --git a/mellon/inference.py b/mellon/inference.py index 653e39a..dd0ca88 100644 --- a/mellon/inference.py +++ b/mellon/inference.py @@ -1,7 +1,7 @@ from collections import namedtuple from functools import partial from jax import random, vmap -from jax.numpy import log, pi, exp, stack, arange, sort, mean, zeros_like +from jax.numpy import log, pi, exp, stack, arange, sort, mean, zeros_like, any from jax.numpy import sum as arraysum from jax.scipy.special import gammaln import jax.scipy.stats.norm as norm @@ -9,14 +9,17 @@ from jax.example_libraries.optimizers import adam from jaxopt import ScipyMinimize from .conditional import ( - FullConditionalMean, - FullConditionalMeanTime, - LandmarksConditionalMean, - LandmarksConditionalMeanTime, - LandmarksConditionalMeanCholesky, - LandmarksConditionalMeanCholeskyTime, + FullConditional, + ExpFullConditional, + FullConditionalTime, + LandmarksConditional, + ExpLandmarksConditional, + LandmarksConditionalTime, + LandmarksConditionalCholesky, + ExpLandmarksConditionalCholesky, + LandmarksConditionalCholeskyTime, ) -from .util import ensure_2d, Exp, DEFAULT_JITTER +from .util import ensure_2d, DEFAULT_JITTER DEFAULT_N_ITER = 100 @@ -298,15 +301,38 @@ def compute_log_density_x(pre_transformation, transform): return transform(pre_transformation) -def compute_conditional_mean( +def compute_parameter_cov_factor(pre_transformation_std, L): + R""" + Computes :math:`\Sigma_L` the left factor of the covariance matrix + of `log_density_x` the mean function of the Gaussian Process on the + training data points. The uncertainty of the mean function comes from + the uncertainty of the inferred model parameters quantified by + `pre_transformation_std`. + + :param pre_transformation_std: Standard deviation of the parameters, e.g., as inferred by ADVI. + :type pre_transformation_std: array-like + :param L: A matrix such that :math:`L L^\top \approx K`, where :math:`K` is the + covariance matrix of the Gaussian Process. + :type L: array-like + :return: sigma_L - The left factor of the covariance matrix of the transformed parameters. + """ + return L * pre_transformation_std[None, :] + + +def compute_conditional( x, landmarks, pre_transformation, + pre_transformation_std, y, mu, cov_func, + L, + Lp=None, sigma=0, jitter=DEFAULT_JITTER, + y_is_mean=False, + with_uncertainty=False, ): R""" Builds the mean function of the Gaussian process, conditioned on the @@ -322,16 +348,32 @@ def compute_conditional_mean( Landmarks can be None if not using landmark points. pre_transformation : array-like or None The pre-transformed latent function representation. + pre_transformation_std : array-like or None + Standard deviation of the parameters, e.g., as inferred by ADVI. y : array-like The function values at each point in x. mu : float The original Gaussian process mean :math:`\mu`. cov_func : function The Gaussian process covariance function. + L : array-like + The matrix :math:`L` used to transform the latent function representation to + thr Gaussin Process mean. Typically :math:`L L^\top \approx K`, where :math:`K` is the + covariance matrix of the Gaussian Process. + Lp : array-like, optional + A matrix such that :math:`L_p L_p^\top = \Sigma_p`, where :math:`\Sigma_p` is the + covariance matrix of the Gaussian Process on the inducing points. sigma : float, optional White noise variance, by default 0. jitter : float, optional A small amount to add to the diagonal for stability, by default 1e-6. 
+ y_is_mean : bool + Wether to consider y the GP mean or a noise measurment + subject to `sigma` or `y_cov_factor`. Has no effect if `L` is passed. + Defaults to False. + with_uncertainty : bool + Wether to compute covariance functions and predictive uncertainty. + Defaults to False. Returns ------- @@ -340,30 +382,54 @@ def compute_conditional_mean( """ if landmarks is None: - return FullConditionalMean( + if with_uncertainty and pre_transformation_std is not None: + y_cov_factor = compute_parameter_cov_factor(pre_transformation_std, L) + else: + y_cov_factor = None + return FullConditional( x, y, mu, cov_func, + Lp, sigma=sigma, jitter=jitter, + y_cov_factor=y_cov_factor, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) elif ( pre_transformation is not None and pre_transformation.shape[0] == landmarks.shape[0] ): landmarks = ensure_2d(landmarks) - return LandmarksConditionalMeanCholesky( + if pre_transformation_std is not None and sigma is not None and any(sigma > 0): + raise ValueError( + "One can specify either `sigma` or `pre_transformation_std` " + "to describe uncertainty, but not both." + ) + elif pre_transformation_std is not None: + sigma = pre_transformation_std + n_obs = x.shape[0] + return LandmarksConditionalCholesky( landmarks, pre_transformation, mu, cov_func, + n_obs, + Lp, sigma=sigma, jitter=jitter, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) else: landmarks = ensure_2d(landmarks) - return LandmarksConditionalMean( + if with_uncertainty and pre_transformation_std is not None: + y_cov_factor = compute_parameter_cov_factor(pre_transformation_std, L) + else: + y_cov_factor = None + return LandmarksConditional( x, landmarks, y, @@ -371,18 +437,26 @@ def compute_conditional_mean( cov_func, sigma=sigma, jitter=jitter, + y_cov_factor=y_cov_factor, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) -def compute_conditional_mean_times( +def compute_conditional_times( x, landmarks, pre_transformation, + pre_transformation_std, y, mu, cov_func, + L, + Lp, sigma=0, jitter=DEFAULT_JITTER, + y_is_mean=False, + with_uncertainty=False, ): R""" Builds the mean function of the Gaussian process, conditioned on the @@ -399,16 +473,31 @@ def compute_conditional_mean_times( Landmarks can be None if not using landmark points. pre_transformation : array-like or None The pre-transformed latent function representation. + pre_transformation_std : array-like or None + Standard deviation of the parameters, e.g., as inferred by ADVI. y : array-like The function values at each point in x. mu : float The original Gaussian process mean :math:`\mu`. cov_func : function The Gaussian process covariance function. + L : array-like + A matrix such that :math:`L L^\top \approx K`, where :math:`K` is the + covariance matrix of the Gaussian Process. + Lp : array-like, optional + A matrix such that :math:`L_p L_p^\top = \Sigma_p`, where :math:`\Sigma_p` is the + covariance matrix of the Gaussian Process on the inducing points. sigma : float, optional White noise variance, by default 0. jitter : float, optional A small amount to add to the diagonal for stability, by default 1e-6. + y_is_mean : bool + Wether to consider y the GP mean or a noise measurment + subject to `sigma` or `y_cov_factor`. Has no effect if `L` is passed. + Defaults to False. + with_uncertainty : bool + Wether to compute covariance functions and predictive uncertainty. + Defaults to False. 
Returns ------- @@ -418,30 +507,54 @@ def compute_conditional_mean_times( """ if landmarks is None: - return FullConditionalMeanTime( + if pre_transformation_std is not None: + y_cov_factor = compute_parameter_cov_factor(pre_transformation_std, L) + else: + y_cov_factor = None + return FullConditionalTime( x, y, mu, cov_func, + Lp, sigma=sigma, jitter=jitter, + y_cov_factor=y_cov_factor, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) elif ( pre_transformation is not None and pre_transformation.shape[0] == landmarks.shape[0] ): landmarks = ensure_2d(landmarks) - return LandmarksConditionalMeanCholeskyTime( + if pre_transformation_std is not None and sigma is not None and any(sigma > 0): + raise ValueError( + "One can specify either `sigma` or `pre_transformation_std` " + "to describe uncertainty, but not both." + ) + elif pre_transformation_std is not None: + sigma = pre_transformation_std + n_obs = x.shape[0] + return LandmarksConditionalCholeskyTime( landmarks, pre_transformation, mu, cov_func, + n_obs, + Lp, sigma=sigma, jitter=jitter, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) else: landmarks = ensure_2d(landmarks) - return LandmarksConditionalMeanTime( + if pre_transformation_std is not None: + y_cov_factor = compute_parameter_cov_factor(pre_transformation_std, L) + else: + y_cov_factor = None + return LandmarksConditionalTime( x, landmarks, y, @@ -449,64 +562,135 @@ def compute_conditional_mean_times( cov_func, sigma=sigma, jitter=jitter, + y_cov_factor=y_cov_factor, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) -def compute_conditional_mean_explog( +def compute_conditional_explog( x, landmarks, + pre_transformation, + pre_transformation_std, y, mu, cov_func, + L, + Lp, sigma=0, jitter=DEFAULT_JITTER, + y_is_mean=False, + with_uncertainty=False, ): R""" - Builds the mean function of the Gaussian process, conditioned on the - function exponential values (e.g., dimensionality) on x. + Builds the exp-mean function of the Gaussian process, conditioned on the + function log-values (e.g., dimensionality) on x. Returns a function that is defined on the whole domain of x. - :param x: The training instances. - :type x: array-like - :param landmarks: The landmark points for fast sparse computation. + Parameters + ---------- + x : array-like + The training instances. + landmarks : array-like or None + The landmark points for fast sparse computation. Landmarks can be None if not using landmark points. - :type landmarks: array-like - :param y: The function values at each point in x. - :type y: array-like - :param mu: The original Gaussian process mean :math:`\mu`. - :type mu: float - :param cov_func: The Gaussian process covariance function. - :type cov_func: function - :param sigma: White moise veriance. Defaults to 0. - :type sigma: float - :param jitter: A small amount to add to the diagonal for stability. Defaults to 1e-6. - :type jitter: float - :return: conditional_mean - The conditioned Gaussian process mean function. - :rtype: function + pre_transformation : array-like or None + The pre-transformed latent function representation. + pre_transformation_std : array-like or None + Standard deviation of the parameters, e.g., as inferred by ADVI. + y : array-like + The function values at each point in x. + mu : float + The original Gaussian process mean :math:`\mu`. + cov_func : function + The Gaussian process covariance function. 
+ L : array-like + The matrix :math:`L` used to transform the latent function representation to + thr Gaussin Process mean. Typically :math:`L L^\top \approx K`, where :math:`K` is the + covariance matrix of the Gaussian Process. + Lp : array-like, optional + A matrix such that :math:`L_p L_p^\top = \Sigma_p`, where :math:`\Sigma_p` is the + covariance matrix of the Gaussian Process on the inducing points. + sigma : float, optional + White noise variance, by default 0. + jitter : float, optional + A small amount to add to the diagonal for stability, by default 1e-6. + y_is_mean : bool + Wether to consider y the GP mean or a noise measurment + subject to `sigma` or `y_cov_factor`. Has no effect if `L` is passed. + Defaults to False. + with_uncertainty : bool + Wether to compute covariance functions and predictive uncertainty. + Defaults to False. + + Returns + ------- + mellon.Predictor + The conditioned Gaussian process mean function. """ + if landmarks is None: - return Exp( - FullConditionalMean( - x, - log(y), - mu, - cov_func, - sigma=sigma, - jitter=jitter, + if with_uncertainty and pre_transformation_std is not None: + y_cov_factor = compute_parameter_cov_factor(pre_transformation_std, L) + else: + y_cov_factor = None + y = log(y) + return ExpFullConditional( + x, + y, + mu, + cov_func, + Lp, + sigma=sigma, + jitter=jitter, + y_cov_factor=y_cov_factor, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, + ) + elif ( + pre_transformation is not None + and pre_transformation.shape[0] == landmarks.shape[0] + ): + landmarks = ensure_2d(landmarks) + if pre_transformation_std is not None and sigma is not None and any(sigma > 0): + raise ValueError( + "One can specify either `sigma` or `pre_transformation_std` " + "to describe uncertainty, but not both." ) + elif pre_transformation_std is not None: + sigma = pre_transformation_std + n_obs = x.shape[0] + return ExpLandmarksConditionalCholesky( + landmarks, + pre_transformation, + mu, + cov_func, + n_obs, + Lp, + sigma=sigma, + jitter=jitter, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) else: landmarks = ensure_2d(landmarks) - return Exp( - LandmarksConditionalMean( - x, - landmarks, - log(y), - mu, - cov_func, - sigma=sigma, - jitter=jitter, - ) + if with_uncertainty and pre_transformation_std is not None: + y_cov_factor = compute_parameter_cov_factor(pre_transformation_std, L) + else: + y_cov_factor = None + y = log(y) + return ExpLandmarksConditional( + x, + landmarks, + y, + mu, + cov_func, + sigma=sigma, + jitter=jitter, + y_cov_factor=y_cov_factor, + y_is_mean=y_is_mean, + with_uncertainty=with_uncertainty, ) @@ -582,7 +766,7 @@ def run_advi( :type init_learn_rate: float :return: parameters, standard deviations, and loss after the optimization :return: Results - A named tuple containing pre_transformation, - pre_transform_std, losses: The optimized parameters, the optimized + pre_transformation_std, losses: The optimized parameters, the optimized standard deviations, and a history of ELBO values. 
:rtype: array-like, array-like, Object """ @@ -614,7 +798,8 @@ def update(i, opt_state): opt_state, elbo = update(t, opt_state) elbos.append(elbo.item()) - params, stds = get_params(opt_state) + params, log_stds = get_params(opt_state) + stds = exp(log_stds) Results = namedtuple("Results", "pre_transformation pre_transformation_std losses") return Results(params, stds, elbos) diff --git a/mellon/meta.yaml b/mellon/meta.yaml deleted file mode 100644 index 59e96c6..0000000 --- a/mellon/meta.yaml +++ /dev/null @@ -1,44 +0,0 @@ -{% set name = "mellon" %} -{% set version = "1.2.0" %} - -package: - name: {{ name|lower }} - version: {{ version }} - -source: - url: https://pypi.io/packages/source/{{ name[0] }}/{{ name }}/{{ name }}-{{ version }}.tar.gz - sha256: b36b45ec9034c868c15210d9e497b3b6b55bdda4ee418dd1bfea0301b5fe0ef8 - -build: - noarch: python - script: {{ PYTHON }} -m pip install . -vv - number: 0 - -requirements: - host: - - python >=3.6 - - pip - run: - - python >=3.6 - - jax - - jaxopt - - scikit-learn - -test: - imports: - - mellon - commands: - - pip check - requires: - - pip - -about: - home: https://github.com/settylab/mellon - summary: Non-parametric density estimator. - license: GPL-3.0-or-later - license_file: LICENSE - -extra: - recipe-maintainers: - - katosh - - ManuSetty diff --git a/mellon/parameter_validation.py b/mellon/parameter_validation.py new file mode 100644 index 0000000..958262f --- /dev/null +++ b/mellon/parameter_validation.py @@ -0,0 +1,279 @@ +import logging +from jax.numpy import ndarray +from .util import GaussianProcessType +from .base_cov import Covariance +from .validation import ( + _validate_positive_int, + _validate_float_or_int, +) + +logger = logging.getLogger("mellon") + + +def _validate_landmark_params(n_landmarks, landmarks): + """ + Validates that n_landmarks and landmarks are compatible. + + Parameters + ---------- + n_landmarks : int + Number of landmarks used in the approximation process. + landmarks : array-like or None + The given landmarks/inducing points. + """ + if landmarks is not None and n_landmarks != landmarks.shape[0]: + n_spec = landmarks.shape[0] + message = ( + f"There are {n_spec:,} landmarks specified but n_landmarks={n_landmarks:,}. " + "Please omit specifying n_landmarks if landmarks are given." + ) + logger.error(message) + raise ValueError(message) + + +def _validate_rank_params(gp_type, n_samples, rank, n_landmarks): + """ + Validates that rank, n_landmarks, and gp_type are compatible. + + Parameters + ---------- + gp_type : GaussianProcessType + The type of the Gaussian Process. It helps to decide the rank. + n_samples : int + The number of samples/cells. + rank : int or float or None + The rank of the approximate covariance matrix. If `None`, it will + be inferred based on the Gaussian Process type. If integer and greater + than or equal to the number of samples or if float and + equal to 1.0 or if 0, full rank is indicated. For FULL_NYSTROEM and + SPARSE_NYSTROEM, it should be fractional 0 < rank < 1 or integer 0 < rank < n. + n_landmarks : int + Number of landmarks used in the approximation process. 
+ """ + if ( + type(rank) is int + and ( + (gp_type == GaussianProcessType.SPARSE_CHOLESKY and rank >= n_landmarks) + or (gp_type == GaussianProcessType.SPARSE_NYSTROEM and rank >= n_landmarks) + or (gp_type == GaussianProcessType.FULL and rank >= n_samples) + or (gp_type == GaussianProcessType.FULL_NYSTROEM and rank >= n_samples) + ) + or type(rank) is float + and rank >= 1.0 + or rank == 0 + ): + # full rank is indicated + if gp_type == GaussianProcessType.FULL_NYSTROEM: + message = ( + f"Gaussian Process type {gp_type} requires " + "fractional 0 < rank < 1 or integer " + f"0 < rank < {n_samples:,} (number of cells) " + f"but the actual rank is {rank}." + ) + logger.error(message) + raise ValueError(message) + elif gp_type == GaussianProcessType.SPARSE_NYSTROEM: + message = ( + f"Gaussian Process type {gp_type} requires " + "fractional 0 < rank < 1 or integer " + f"0 < rank < {n_landmarks:,} (number of landmakrs) " + f"but the actual rank is {rank}." + ) + logger.error(message) + raise ValueError(message) + elif ( + gp_type != GaussianProcessType.FULL_NYSTROEM + and gp_type != GaussianProcessType.SPARSE_NYSTROEM + ): + message = ( + f"Given rank {rank} indicates Nyström rank reduction. " + f"But the Gaussian Process type is set to {gp_type}." + ) + logger.error(message) + raise ValueError(message) + + +def _validate_gp_type(gp_type, n_samples, n_landmarks): + """ + Validates that gp_type, n_samples, and n_landmarks are compatible. + + Parameters + ---------- + gp_type : GaussianProcessType + The type of the Gaussian Process. It helps to decide the rank. + n_samples : int + The number of samples/cells. + n_landmarks : int + Number of landmarks used in the approximation process. + """ + # Validation logic for FULL and FULL_NYSTROEM types + if ( + ( + gp_type == GaussianProcessType.FULL + or gp_type == GaussianProcessType.FULL_NYSTROEM + ) + and n_landmarks != 0 + and n_landmarks < n_samples + ): + message = ( + f"Gaussian Process type {gp_type} but n_landmarks={n_landmarks:,} is smaller " + f"than the number of cells {n_samples:,}. Omit n_landmarks or set it to 0 to use " + "a non-sparse Gaussian Process or omit gp_type to use a sparse one." + ) + logger.error(message) + raise ValueError(message) + + # Validation logic for SPARSE_CHOLESKY and SPARSE_NYSTROEM types + elif ( + gp_type == GaussianProcessType.SPARSE_CHOLESKY + or gp_type == GaussianProcessType.SPARSE_NYSTROEM + ): + if n_landmarks == 0: + message = ( + f"Gaussian Process type {gp_type} but n_landmarks=0. Set n_landmarks " + f"to a number smaller than the number of cells {n_samples:,} to use a" + "sparse Gaussuian Process or omit gp_type to use a non-sparse one." + ) + logger.error(message) + raise ValueError(message) + elif n_landmarks >= n_samples: + message = ( + f"Gaussian Process type {gp_type} but n_landmarks={n_landmarks:,} is larger or " + f"equal the number of cells {n_samples:,}. Reduce the number of landmarks to use a" + "sparse Gaussuian Process or omit gp_type to use a non-sparse one.." + ) + logger.error(message) + raise ValueError(message) + + +def _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks): + """ + Validates that rank, gp_type, n_samples, n_landmarks, and landmarks are compatible. + + Parameters + ---------- + rank : int or float or None + The rank of the approximate covariance matrix. If `None`, it will + be inferred based on the Gaussian Process type. If integer and greater + than or equal to the number of samples or if float and + equal to 1.0 or if 0, full rank is indicated. 
For FULL_NYSTROEM and + SPARSE_NYSTROEM, it should be fractional 0 < rank < 1 or integer 0 < rank < n. + gp_type : GaussianProcessType + The type of the Gaussian Process. It helps to decide the rank. + n_samples : int + The number of samples/cells. + n_landmarks : int + Number of landmarks used in the approximation process. + landmarks : array-like or None + The given landmarks/inducing points. + """ + + n_landmarks = _validate_positive_int(n_landmarks, "n_landmarks") + rank = _validate_float_or_int(rank, "rank") + + if not isinstance(gp_type, GaussianProcessType): + message = ( + "gp_type needs to be a mellon.util.GaussianProcessType but is a " + f"{type(gp_type)} instead." + ) + logger.error(message) + raise ValueError(message) + + # Validation logic for landmarks + _validate_landmark_params(n_landmarks, landmarks) + if n_landmarks > n_samples: + logger.warning( + f"n_landmarks={n_landmarks:,} is larger than the number of cells {n_samples:,}." + ) + + _validate_gp_type(gp_type, n_samples, n_landmarks) + + # Validation logic for rank + _validate_rank_params(gp_type, n_samples, rank, n_landmarks) + + +def _validate_cov_func_curry(cov_func_curry, cov_func, param_name): + """ + Validates covariance function curry type. + + Parameters + ---------- + cov_func_curry : type or None + The covariance function curry type to be validated. + cov_func : mellon.Covariance or None + An instance of covariance function. + param_name : str + The name of the parameter to be used in the error message. + + Returns + ------- + type + The validated covariance function curry type. + + Raises + ------ + ValueError + If both 'cov_func_curry' and 'cov_func' are None, or if 'cov_func_curry' is not a subclass of mellon.Covariance. + """ + + if cov_func_curry is None and cov_func is None: + raise ValueError( + "At least one of 'cov_func_curry' and 'cov_func' must not be None" + ) + + if cov_func_curry is not None: + if not isinstance(cov_func_curry, type) or not issubclass( + cov_func_curry, Covariance + ): + raise ValueError(f"'{param_name}' must be a subclass of mellon.Covariance") + return cov_func_curry + + +def _validate_cov_func(cov_func, param_name, optional=False): + """ + Validates an instance of a covariance function. + + Parameters + ---------- + cov_func : mellon.Covariance or None + The covariance function instance to be validated. + param_name : str + The name of the parameter to be used in the error message. + optional : bool, optional + Whether the value is optional. If optional and value is None, returns None. Default is False. + + Returns + ------- + mellon.Covariance or None + The validated instance of a subclass of mellon.Covariance or None if optional. + + Raises + ------ + ValueError + If 'cov_func' is not an instance of a subclass of mellon.Covariance and not None when not optional. + """ + + if cov_func is None and optional: + return None + + if not isinstance(cov_func, Covariance): + raise ValueError( + f"'{param_name}' must be an instance of a subclass of mellon.Covariance" + ) + return cov_func + + +def _validate_normalize_parameter(normalize, unique_times): + """ + Used in parameters.compute_nn_distances_within_time_points to validate input. 
+ """ + if isinstance(normalize, dict): + missing_times = [t for t in unique_times if t.item() not in normalize] + if missing_times: + raise ValueError( + f"Missing time point(s) in normalization dictionary: {missing_times}" + ) + elif isinstance(normalize, (list, ndarray)) and len(normalize) != len(unique_times): + raise ValueError( + "Length of the normalize list or array must match the number of unique time points." + ) diff --git a/mellon/parameters.py b/mellon/parameters.py index 1973a00..296f9cf 100644 --- a/mellon/parameters.py +++ b/mellon/parameters.py @@ -1,25 +1,209 @@ -from jax.numpy import exp, log, quantile, stack, unique, empty +import logging +from jax.numpy import ( + exp, + log, + quantile, + stack, + unique, + empty, + where, + ndim, + asarray, + ndarray, +) from jax.numpy import sum as arraysum +from jax.numpy import any as arrayany +from jax.numpy import all as arrayall +from jax.numpy import min as arraymin from jax import random from sklearn.cluster import k_means from sklearn.linear_model import Ridge from sklearn.neighbors import BallTree, KDTree -from .util import mle, local_dimensionality, Log, ensure_2d, DEFAULT_JITTER +from .util import ( + mle, + local_dimensionality, + ensure_2d, + DEFAULT_JITTER, + GaussianProcessType, +) from .decomposition import ( - _check_method, _full_rank, _full_decomposition_low_rank, _standard_low_rank, _modified_low_rank, DEFAULT_RANK, - DEFAULT_METHOD, + DEFAULT_SIGMA, +) +from .validation import ( + _validate_time_x, + _validate_positive_float, + _validate_float_or_int, + _validate_positive_int, + _validate_float_or_iterable_numerical, +) +from .parameter_validation import ( + _validate_params, + _validate_normalize_parameter, ) -from .validation import _validate_time_x, _validate_positive_float DEFAULT_N_LANDMARKS = 5000 -logger = Log() +logger = logging.getLogger("mellon") + + +def compute_rank(gp_type): + """ + Compute the appropriate rank reduction based on the given Gaussian Process type. + + Parameters + ---------- + gp_type : GaussianProcessType + The type of the Gaussian Process. It helps to decide the rank. + + Returns + ------- + computed_rank : float or int or None + The computed rank value based on the `gp_type`, `rank`, and shape of `x`. + + Raises + ------ + ValueError + If the given rank and Gaussian Process type conflict with each other. + """ + + if gp_type is None: + return 1.0 + elif (gp_type == GaussianProcessType.FULL_NYSTROEM) or ( + gp_type == GaussianProcessType.SPARSE_NYSTROEM + ): + return DEFAULT_RANK + else: + return 1.0 + + +def compute_n_landmarks(gp_type, n_samples, landmarks): + """ + Compute the number of landmarks based on the given Gaussian Process type and landmarks. + + Parameters + ---------- + gp_type : GaussianProcessType + The type of the Gaussian Process. It helps to decide the number of landmarks. + n_samples : array-like + The number of samples/cells. + landmarks : array-like or None + The given landmarks. If specified, its shape determines the number of landmarks, + unless conflicting with `n_landmarks`. + + Returns + ------- + computed_n_landmarks : int + The computed number of landmarks based on the `gp_type`, `n_landmarks`, shape of `x`, and `landmarks`. + + Raises + ------ + ValueError + If the given number of landmarks, Gaussian Process type, and landmarks conflict with each other. 
+ + """ + if landmarks is not None: + return landmarks.shape[0] + + if gp_type is None: + n_landmarks = min(n_samples, DEFAULT_N_LANDMARKS) + elif ( + gp_type == GaussianProcessType.FULL + or gp_type == GaussianProcessType.FULL_NYSTROEM + ): + n_landmarks = n_samples + elif ( + gp_type == GaussianProcessType.SPARSE_CHOLESKY + or gp_type == GaussianProcessType.SPARSE_NYSTROEM + ): + if n_samples <= DEFAULT_N_LANDMARKS: + message = ( + f"Gaussian Process type {gp_type} and default " + f"number of landmarks {DEFAULT_N_LANDMARKS:,} < " + f"number of cells {n_samples:,}. Reduce n_landmarks below " + f"the number of cells to use {gp_type}." + ) + logger.warning(message) + n_landmarks = DEFAULT_N_LANDMARKS + else: + n_landmarks = min(n_samples, DEFAULT_N_LANDMARKS) + logger.warning( + f"Unknown Gaussian Process type {gp_type}, using default " + f"n_landmarks={n_landmarks:,}." + ) + return n_landmarks + + +def compute_gp_type(n_landmarks, rank, n_samples): + """ + Determines the type of Gaussian Process based on the landmarks, rank, and method values. + + Parameters + ---------- + landmarks : array-like or None + The landmark points for sparse computation. + rank : int or float + The rank of the approximate covariance matrix. + n_samples : array-like + The number of samples/cells. + + Returns + ------- + GaussianProcessType + One of the Gaussian Process types defined in the `GaussianProcessType` Enum. + """ + + rank = _validate_float_or_int(rank, "rank", optional=True) + n_landmarks = _validate_positive_int(n_landmarks, "n_landmarks") + n_samples = _validate_positive_int(n_samples, "n_samples") + + if n_landmarks == 0 or n_landmarks >= n_samples: + # Full model + if ( + rank is None + or type(rank) is int + and (rank >= n_samples) + or type(rank) is float + and rank >= 1.0 + or rank == 0 + ): + logger.info( + "Using non-sparse Gaussian Process since n_landmarks " + f"({n_landmarks:,}) >= n_samples ({n_samples:,}) and rank = {rank}." + ) + return GaussianProcessType.FULL + else: + logger.info( + "Using full Gaussian Process with Nyström rank reduction since n_landmarks " + f"({n_landmarks:,}) >= n_samples ({n_samples:,}) and rank = {rank}." + ) + return GaussianProcessType.FULL_NYSTROEM + else: + # Sparse model + if ( + rank is None + or type(rank) is int + and (rank >= n_landmarks) + or type(rank) is float + and rank >= 1.0 + or rank == 0 + ): + logger.info( + "Using sparse Gaussian Process since n_landmarks " + f"({n_landmarks:,}) < n_samples ({n_samples:,}) and rank = {rank}." + ) + return GaussianProcessType.SPARSE_CHOLESKY + else: + logger.info( + "Using sparse Gaussian Process with improved Nyström rank reduction since n_landmarks " + f"({n_landmarks:,}) >= n_samples ({n_samples:,}) and rank = {rank}." + ) + return GaussianProcessType.SPARSE_NYSTROEM def compute_landmarks(x, n_landmarks=DEFAULT_N_LANDMARKS): @@ -133,17 +317,19 @@ def compute_distances(x, k): return distances -def compute_nn_distances(x): +def compute_nn_distances(x, save=True): """ Compute the distance to the nearest neighbor for each instance in the provided training dataset. This function calculates the Euclidean distance between each instance in the dataset and its closest neighbor. - It returns an array of these distances, ordered in the same way as the input instances. + If save=True, any non-positive distances will be replaced with the minimum positive distance. Parameters ---------- x : array-like of shape (n_samples, n_features) An array-like object representing the training instances. 
+ save : bool, optional + Whether to replace non-positive distances with the minimum positive distance. Default is True. Returns ------- @@ -152,11 +338,37 @@ def compute_nn_distances(x): the input dataset. The ordering of the distances in this array corresponds to the ordering of the instances in the input data. + Raises + ------ + ValueError : if all distances are non-positive and save=True. """ - return compute_distances(x, 1)[:, 0] + nn_distances = compute_distances(x, 1)[:, 0] + + if save and arrayany(nn_distances <= 0): + good_idx = nn_distances > 0 + if arrayall(~good_idx): + message = "All instances seem to be identical." + logger.error(message) + raise ValueError(message) + min_positive = arraymin(nn_distances[good_idx]) + n_identical = arraysum(~good_idx) + logger.warning( + f"Found {n_identical:,} identical cells. Adding {min_positive} to their pairwise distance." + ) + nn_distances = where(good_idx, nn_distances, min_positive) + + return nn_distances + +def _get_target_cell_count(normalize, time, av_cells_per_tp, unique_times): + if isinstance(normalize, bool): + return av_cells_per_tp + if isinstance(normalize, dict): + return normalize[time.item()] + return normalize[unique_times.tolist().index(time)] -def compute_nn_distances_within_time_points(x, times=None, normalize=False): + +def compute_nn_distances_within_time_points(x, times=None, d=None, normalize=False): R""" Computes the distance to the nearest neighbor for each training instance within the same time point group. It retains the original order of instances in `x`. @@ -173,10 +385,24 @@ def compute_nn_distances_within_time_points(x, times=None, normalize=False): If provided, it overrides the last column of 'x' as the times. Shape must be either (n_samples,) or (n_samples, 1). - normalize : bool, optional - If True, distances are normalized by the number of samples within the same time point group. - This normalization reduces potential bias in the density estimation arising from uneven - sampling across different time points. Defaults to False. + d : int, array-like or None + The intrinsic dimensionality of the data, i.e., the dimensionality of the embedded + manifold. Only required for the normalization. + Defaults to None. + + normalize : bool, list, array-like, or dict, optional + Controls the normalization for varying cell counts across time points to adjust for sampling bias + by modifying the nearest neighbor distances. + + - If True, normalizes to simulate a constant total cell count divided by the number of time points. + + - If False, the raw cell counts per time point is reflected in the nearest neighbor distances. + + - If a list or array-like, assumes total cell counts for time points, ordered from earliest to latest. + + - If a dict, maps each time point to its total cell count. Must cover all unique time points. + + Default is False. Returns ------- @@ -191,6 +417,21 @@ def compute_nn_distances_within_time_points(x, times=None, normalize=False): n_cells = x.shape[0] av_cells_per_tp = n_cells / len(unique_times) + _validate_normalize_parameter(normalize, unique_times) + + if normalize is not False and normalize is not None: + d = _validate_float_or_iterable_numerical(d, "d", optional=False, positive=True) + if ndim(d) > 0 and len(d) != x.shape[0]: + ld = len(d) + raise ValueError( + f"If `d` (length={ld:,}) is a vector then it needs to have one value " + f"per cell in x (x.shape[0]={n_cells:,})." 
+ ) + logger.info( + "Normalizing nearest neighbor distances correcting sampling bias for " + f"{len(unique_times):,} different time points." + ) + for time in unique_times: mask = x[:, -1] == time n_samples = arraysum(mask) @@ -199,13 +440,19 @@ def compute_nn_distances_within_time_points(x, times=None, normalize=False): f"Insufficient data: Only {n_samples} sample(s) found at time point {time}. " "Nearest neighbors cannot be computed with less than two samples per time point. " "Please confirm if you have provided the correct time axis. " - "If the time points indeed have very few samples, consider aggregating nearby time points for better results, " - "or you may specify `nn_distances` manually." + "If the time points indeed have very few samples, consider aggregating nearby " + "time points for better results, or you may specify `nn_distances` manually." ) x_at_time = x[mask, :-1] nn_distances_at_time = compute_nn_distances(x_at_time) - if normalize: - nn_distances_at_time = nn_distances_at_time * n_samples / av_cells_per_tp + if normalize is not False and normalize is not None: + target_cell_count = _get_target_cell_count( + normalize, time, av_cells_per_tp, unique_times + ) + factor = (n_samples / target_cell_count) ** ( + 1 / d if ndim(d) == 0 else 1 / d[mask] + ) + nn_distances_at_time = factor * nn_distances_at_time nn_distances = nn_distances.at[mask].set(nn_distances_at_time) return nn_distances @@ -224,16 +471,33 @@ def compute_d(x): def compute_d_factal(x, k=10, n=500, seed=432): - R""" + """ Computes the dimensionality of the data based on the average fractal - dimension around n randomly selected cells. + dimension around `n` randomly selected cells. + + Parameters + ---------- + x : array-like + The training instances. Shape must be (n_samples, n_features). + k : int, optional + Number of nearest neighbors to use in the algorithm. + Defaults to 10. + n : int, optional + Number of samples to randomly select. + Defaults to 500. + seed : int, optional + Random seed for sampling. Defaults to 432. + + Returns + ------- + float + The average fractal dimension of the data. + + Warnings + -------- + If `k` is greater than the number of samples in `x`, a warning will + be logged, and `k` will be set to the number of samples. - :param x: The training instances. - :type x: array-like - :param n: Number of samples. - :type n: int - :param seed: Random seed for sampling. - :type seed: int """ if len(x.shape) < 2: return 1 @@ -244,7 +508,7 @@ def compute_d_factal(x, k=10, n=500, seed=432): else: x_query = x local_dims = local_dimensionality(x, k=k, x_query=x_query) - return local_dims.mean() + return local_dims.mean().item() def compute_mu(nn_distances, d): @@ -265,13 +529,15 @@ def compute_mu(nn_distances, d): def compute_ls(nn_distances): R""" - Computes ls equal to the geometric mean of the nearest neighbor distances times a constant. + Computes a length scale (ls) equal to the geometric mean of the positive nearest neighbor distances + times a constant. - :param nn_distances: The observed nearest neighbor distances. + :param nn_distances: The observed nearest neighbor distances. Must be non-empty. :type nn_distances: array-like - :return: ls - The geometric mean of the nearest neighbor distances times a constant. + :return: ls - The geometric mean of the nearest neighbor distances (after adjustment) times a constant. 
:rtype: float """ + return exp(log(nn_distances).mean() + 3.0).item() @@ -301,95 +567,224 @@ def compute_cov_func(cov_func_curry, ls, ls_time=None): Otherwise, it's a single covariance function considering only the feature dimensions. """ if ls_time is not None: - return cov_func_curry(ls, active_dims=slice(None, -1)) * cov_func_curry( - ls_time, active_dims=-1 + return cov_func_curry(ls=ls, active_dims=slice(None, -1)) * cov_func_curry( + ls=ls_time, active_dims=-1 ) - return cov_func_curry(ls) + return cov_func_curry(ls=ls) + + +def compute_Lp( + x, + cov_func, + gp_type=None, + landmarks=None, + sigma=DEFAULT_SIGMA, + jitter=DEFAULT_JITTER, +): + R""" + Compute a matrix :math:`L_p` such that :math:`L_p L_p^\top = \Sigma_p` + where :math:`\Sigma_p` is the full rank covariance matrix on the + inducing points. Unless a full Nyström method or sparse Nyström method + is used, in which case None is returned. + + Parameters + ---------- + x : array-like + The training instances. + cov_func : function + The Gaussian process covariance function. + gp_type : str or GaussianProcessType + The type of sparcification used for the Gaussian Process: + - 'full' None-sparse Gaussian Process + - 'sparse_cholesky' Sparse GP using landmarks/inducing points, + typically employed to enable scalable GP models. + landmarks : array-like + The landmark points. + sigma : float, optional + Noise standard deviation of the data we condition on. Defaults to 0. + jitter : float, optional + A small amount to add to the diagonal. Defaults to 1e-6. + + Returns + ------- + array-like or None + :math:`L_p` - A matrix such that :math:`L_p L_p^\top = \Sigma_p`, + or None if using full or sparse Nyström. + """ + x = ensure_2d(x) + n_samples = x.shape[0] + if landmarks is None: + n_landmarks = n_samples + landmarks = x + else: + n_landmarks = landmarks.shape[0] + gp_type = GaussianProcessType.from_string(gp_type, optional=True) + if gp_type is None: + gp_type = compute_gp_type(n_landmarks, 1.0, n_samples) + + if ( + gp_type == GaussianProcessType.FULL_NYSTROEM + or gp_type == GaussianProcessType.SPARSE_NYSTROEM + ): + return None + elif gp_type == GaussianProcessType.FULL: + return _full_rank(x, cov_func, sigma=sigma, jitter=jitter) + elif gp_type == GaussianProcessType.SPARSE_CHOLESKY: + return _full_rank(landmarks, cov_func, sigma=sigma, jitter=jitter) + else: + message = f"Unknown Gaussian Process type {gp_type}." + logger.error(message) + raise ValueError(message) + + +def _validate_compute_L_input(x, cov_func, gp_type, landmarks, Lp, rank, sigma, jitter): + """ + Validate input for the fuction compute_L. + + Returns + ------- + x : array-like + The training instances with at least 2 dimensions (n_samples, n_dims). + n_landmarks : int + The number of landmarks. + n_samples : int + The number of samples/cells. + gp_type : mellon.util.GaussianProcessType + The type of Gaussian Process to use. + rank : float, optional + The rank of the approximate covariance matrix. + + Raises + ------ + ValueError + If any of the inputs are inconsistent or violate constraints. 
+ """ + jitter = _validate_positive_float(jitter, "jitter") + rank = _validate_float_or_int(rank, "rank", optional=True) + + n_samples = x.shape[0] + if landmarks is None: + n_landmarks = n_samples + else: + n_landmarks = landmarks.shape[0] + gp_type = GaussianProcessType.from_string(gp_type, optional=True) + if rank is None: + rank = compute_rank(gp_type) + if gp_type is None: + gp_type = compute_gp_type(n_landmarks, rank, n_samples) + _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks) + + if ( + gp_type == GaussianProcessType.FULL + and Lp is not None + and Lp.shape != (n_samples, n_samples) + ): + message = ( + f" Wrong shape of Lp {Lp.shape} for {gp_type} and {n_samples:,} samples." + ) + logger.error(message) + raise ValueError(message) + elif ( + gp_type == GaussianProcessType.SPARSE_CHOLESKY + and Lp is not None + and Lp.shape != (n_landmarks, n_landmarks) + ): + message = f" Wrong shape of Lp {Lp.shape} for {gp_type} and {n_landmarks:,} landmarks." + logger.error(message) + raise ValueError(message) + + return x, n_landmarks, n_samples, gp_type, rank def compute_L( x, cov_func, + gp_type=None, landmarks=None, - rank=DEFAULT_RANK, - method=DEFAULT_METHOD, + Lp=None, + rank=None, + sigma=DEFAULT_SIGMA, jitter=DEFAULT_JITTER, ): R""" - Compute an :math:`L` such that :math:`L L^\top \approx K`, where - :math:`K` is the covariance matrix. + Compute a low rank :math:`L` such that :math:`L L^\top \approx K`, + where :math:`K` is the full rank covariance matrix on `x`. - :param x: The training instances. - :type x: array-like - :param cov_func: The Gaussian process covariance function. - :type cov_func: function - :param landmarks: The landmark points. If None, computes a full rank decompostion. - Defaults to None. - :type landmarks: array-like - :param rank: The rank of the approximate covariance matrix. + Parameters + ---------- + x : array-like + The training instances. + cov_func : function + The Gaussian process covariance function. + gp_type : str or GaussianProcessType + The type of sparcification used for the Gaussian Process: + - 'full' None-sparse Gaussian Process + - 'full_nystroem' Sparse GP with Nyström rank reduction without landmarks, + which lowers the computational complexity. + - 'sparse_cholesky' Sparse GP using landmarks/inducing points, + typically employed to enable scalable GP models. + - 'sparse_nystroem' Sparse GP using landmarks or inducing points, + along with an improved Nyström rank reduction method that balances + accuracy with efficiency. + landmarks : array-like, optional + The landmark points. If None, computes a full rank decomposition. Defaults to None. + rank : int or float, optional + The rank of the approximate covariance matrix. If rank is an int, an :math:`n \times` rank matrix :math:`L` is computed such that :math:`L L^\top \approx K`, the exact :math:`n \times n` covariance matrix. If rank is a float 0.0 :math:`\le` rank :math:`\le` 1.0, the rank/size of :math:`L` is selected such that the included eigenvalues of the covariance between landmark points account for the specified percentage of the - sum of eigenvalues. Defaults to 0.999. - :type rank: int or float - :param method: Explicitly specifies whether rank is to be interpreted as a - fixed number of eigenvectors or a percent of eigenvalues to include - in the low rank approximation. Supports 'fixed', 'percent', or 'auto'. - If 'auto', interprets rank as a fixed number of eigenvectors if it is - an int and interprets rank as a percent of eigenvalues if it is a float. 
- Defaults to 'auto'. - :type method: str - :param jitter: A small amount to add to the diagonal. Defaults to 1e-6. - :type jitter: float - :return: :math:`L` - A matrix such that :math:`L L^\top \approx K`. - :rtype: array-like - """ - x = ensure_2d(x) - n_samples = x.shape[0] - if landmarks is None: - n = x.shape[0] - method = _check_method(rank, n, method) - - if type(rank) is int and rank == n or type(rank) is float and rank == 1.0: - logger.info( - f"Doing full-rank Cholesky decomposition for {n_samples:,} samples." - ) - return _full_rank(x, cov_func, jitter=jitter) - else: - logger.info( - f"Doing full-rank singular value decomposition for {n_samples:,} samples." - ) - return _full_decomposition_low_rank( - x, cov_func, rank=rank, method=method, jitter=jitter - ) - else: - landmarks = ensure_2d(landmarks) + sum of eigenvalues. Defaults to 0.99 if gp_type indicates Nyström. + sigma : float, array-like, optional + Noise standard deviation of the data we condition on. Defaults to 0. + jitter : float, optional + A small amount to add to the diagonal. Defaults to 1e-6. + Lp : array-like, optional + Prespecified matrix :math:`L_p` sich that :math:`L_p L_p^\top = \Sigma_p` + where :math:`\Sigma_p` is the full rank covariance matrix on the + inducing points. Defaults to None. - n_landmarks = landmarks.shape[0] - method = _check_method(rank, n_landmarks, method) + Returns + ------- + array-like + :math:`L` - Matrix such that :math:`L L^\top \approx K`. - if ( - type(rank) is int - and rank == n_landmarks - or type(rank) is float - and rank == 1.0 - ): - logger.info( - "Doing low-rank Cholesky decomposition for " - f"{n_samples:,} samples and {n_landmarks:,} landmarks." - ) - return _standard_low_rank(x, cov_func, landmarks, jitter=jitter) - else: - logger.info( - "Doing low-rank improved Nyström decomposition for " - f"{n_samples:,} samples and {n_landmarks:,} landmarks." - ) - return _modified_low_rank( - x, cov_func, landmarks, rank=rank, method=method, jitter=jitter + Raises + ------ + ValueError + If the Gaussian Process type is unknown or if the shape of Lp is incorrect. + """ + x, n_landmarks, n_samples, gp_type, rank = _validate_compute_L_input( + x, cov_func, gp_type, landmarks, Lp, rank, sigma, jitter + ) + + if gp_type == GaussianProcessType.FULL: + if Lp is None: + return _full_rank(x, cov_func, sigma=sigma, jitter=jitter) + return Lp + elif gp_type == GaussianProcessType.FULL_NYSTROEM: + return _full_decomposition_low_rank( + x, cov_func, rank=rank, sigma=sigma, jitter=jitter + ) + elif gp_type == GaussianProcessType.SPARSE_CHOLESKY: + if Lp is None: + return _standard_low_rank( + x, cov_func, landmarks, sigma=sigma, jitter=jitter ) + return _standard_low_rank( + x, cov_func, landmarks, Lp=Lp, sigma=sigma, jitter=jitter + ) + elif gp_type == GaussianProcessType.SPARSE_NYSTROEM: + return _modified_low_rank( + x, + cov_func, + landmarks, + rank=rank, + sigma=sigma, + jitter=jitter, + ) def compute_initial_value(nn_distances, d, mu, L): @@ -438,3 +833,48 @@ def compute_initial_dimensionalities(x, mu_dim, mu_dens, L, nn_distances, d): initial_dens = compute_initial_value(nn_distances, d, mu_dens, L) initial_value = stack([initial_dims, initial_dens]) return initial_value + + +def compute_average_cell_count(x, normalize): + """ + Compute the average cell count based on the `normalize` parameter and the input data `x`. + + Parameters + ---------- + x : jax.numpy.ndarray + Input array with shape (n_samples, n_features). + The last column is assumed to contain the time identifiers. 
+ + normalize : bool, list, jax.numpy.ndarray, dict, or None + The parameter controlling the normalization. + + - If True or None, returns the average cell count computed from `x`. + + - If a list or jax.numpy.ndarray, returns the average of the list or array. + + - If a dict, returns the average of the dict values. + + Returns + ------- + float + The average cell count computed based on the `normalize` parameter and `x`. + + Raises + ------ + ValueError + If the type of `normalize` is not recognized. + """ + n_cells = x.shape[0] + unique_times = unique(x[:, -1]) + n_unique_times = unique_times.shape[0] + + if normalize is None or isinstance(normalize, bool): + return n_cells / n_unique_times + + if isinstance(normalize, dict): + return sum(normalize.values()) / n_unique_times + + if isinstance(normalize, (list, ndarray)): + return arraysum(asarray(normalize)) / len(normalize) + + raise ValueError(f"Unrecognized type for 'normalize': {type(normalize)}") diff --git a/mellon/time_sensitive_density_estimator.py b/mellon/time_sensitive_density_estimator.py index 4c77ed4..6b08e47 100644 --- a/mellon/time_sensitive_density_estimator.py +++ b/mellon/time_sensitive_density_estimator.py @@ -1,10 +1,10 @@ -from .decomposition import DEFAULT_RANK, DEFAULT_METHOD +import logging from .base_model import BaseEstimator, DEFAULT_COV_FUNC from .inference import ( compute_transform, compute_loss_func, compute_log_density_x, - compute_conditional_mean_times, + compute_conditional_times, DEFAULT_N_ITER, DEFAULT_INIT_LEARN_RATE, DEFAULT_JIT, @@ -15,28 +15,28 @@ compute_landmarks_rescale_time, compute_cov_func, compute_d, + compute_ls, compute_d_factal, compute_mu, compute_initial_value, - DEFAULT_N_LANDMARKS, + compute_average_cell_count, ) from .compute_ls_time import compute_ls_time from .util import ( DEFAULT_JITTER, - Log, + object_str, ) from .validation import ( _validate_time_x, _validate_positive_float, _validate_string, _validate_array, - _validate_bool, ) DEFAULT_D_METHOD = "embedding" -logger = Log() +logger = logging.getLogger("mellon") class TimeSensitiveDensityEstimator(BaseEstimator): @@ -56,24 +56,34 @@ class TimeSensitiveDensityEstimator(BaseEstimator): Defaults to Matern52. n_landmarks : int - The number of landmark points. If less than 1 or greater than or equal to the - number of training points, inducing points will not be computed or used. - Defaults to 5000. + The number of landmark/inducing points. Only used if a sparse GP is indicated + through gp_type. If 0 or equal to the number of training points, inducing points + will not be computed or used. Defaults to 5000. rank : int or float - The rank of the approximate covariance matrix. If rank is an int, an :math:`n \times` + The rank of the approximate covariance matrix for the Nyström rank reduction. + If rank is an int, an :math:`n \times` rank matrix :math:`L` is computed such that :math:`L L^\top \approx K`, where `K` is the exact :math:`n \times n` covariance matrix. If rank is a float 0.0 :math:`\le` rank :math:`\le` 1.0, the rank/size of :math:`L` is selected such that the included eigenvalues of the covariance between landmark points account for the specified percentage of the sum - of eigenvalues. Defaults to 0.99. - - method : str - Determines how the rank is interpreted: as a fixed number of eigenvectors ('fixed'), a - percent of eigenvalues ('percent'), or automatically ('auto'). If 'auto', the rank is - interpreted as a fixed number of eigenvectors if it is an int and as a percent of - eigenvalues if it is a float. 
This parameter is provided for clarity in the ambiguous case - of 1 vs 1.0. Defaults to 'auto'. + of eigenvalues. It is ignored if gp_type does not indicate a Nyström rank reduction. + Defaults to 0.99. + + gp_type : str or GaussianProcessType + The type of sparcification used for the Gaussian Process: + - 'full' None-sparse Gaussian Process + - 'full_nystroem' Sparse GP with Nyström rank reduction without landmarks, + which lowers the computational complexity. + - 'sparse_cholesky' Sparse GP using landmarks/inducing points, + typically employed to enable scalable GP models. + - 'sparse_nystroem' Sparse GP using landmarks or inducing points, + along with an improved Nyström rank reduction method. + + The value can be either a string matching one of the above options or an instance of + the `mellon.util.GaussianProcessType` Enum. If a partial match is found with the + Enum, a warning will be logged, and the closest match will be used. + Defaults to 'sparse_cholesky'. d_method : str The method to compute the intrinsic dimensionality of the data. Implemented options are @@ -108,13 +118,24 @@ class TimeSensitiveDensityEstimator(BaseEstimator): distances are computed automatically, using a KDTree if the dimensionality of the data is less than 20, or a BallTree otherwise. Defaults to None. - normalize_per_time_point : bool, optional - If True, the computation of nearest neighbor distances incorporates a normalization step - based on the number of samples within each time point group. This process mitigates potential - bias in the density estimation due to unequal sample distribution across different time points. - Note that if `nn_distance` is provided, this parameter is ignored as the provided distances - are used directly. - Defaults to False. + normalize_per_time_point : bool, list, array-like, or dict, optional + Controls the normalization for varying cell counts across time points to adjust for sampling bias + by modifying the nearest neighbor distances before inference. + + - If True, normalizes to simulate a constant total cell count divided by the number of time points. + + - If False, the raw cell counts per time point is reflected in the density estimation. + + - If a list or array-like, assumes total cell counts for time points, ordered from earliest to latest. + + - If a dict, maps each time point to its total cell count. Must cover all unique time points. + + Note: Relative cell counts are sufficient for comparison within dataset; exact numbers are not required. + + Note: Ignored if `nn_distance` is provided; distances are used as-is and this parameter has no effect. + + Default is False. + d : int, array-like or None The intrinsic dimensionality of the data, i.e., the dimensionality of the embedded @@ -157,12 +178,16 @@ class TimeSensitiveDensityEstimator(BaseEstimator): Default is an empty dictionary ({}). cov_func : mellon.Covariance or None - The Gaussian process covariance function of the form k(x, y) :math:`\rightarrow` float. - Should be an instance of a class that inherits from :class:`mellon.Covariance`. + The Gaussian process covariance function as instance of :class:`mellon.Covaraince`. If None, the covariance function `cov_func` is automatically generated as `cov_func_curry(ls, active_dims=slice(None, -1)) * cov_func_curry(ls_time, active_dims=-1)`. Defaults to None. + Lp : array-like or None + A matrix such that :math:`L_p L_p^\top = \Sigma_p`, where :math:`\Sigma_p` is the + covariance matrix of the inducing points (all cells in non-sparse GP). 
+ Not used when Nyström rank reduction is employed. Defaults to None. + L : array-like or None A matrix such that :math:`L L^\top \approx K`, where :math:`K` is the covariance matrix. If None, `L` is computed automatically. Defaults to None. @@ -173,23 +198,44 @@ class TimeSensitiveDensityEstimator(BaseEstimator): - (d/2) \cdot \log(\pi) - d \cdot \log(\text{nn_distances})` and :math:`d` is the intrinsic dimensionality of the data. Defaults to None. + predictor_with_uncertainty : bool + If set to True, computes the predictor instance `.predict` with its predictive uncertainty. + The uncertainty comes from two sources: + + 1) `.predict.mean_covariance`: + Uncertainty arising from the posterior distribution of the Bayesian inference. + This component quantifies uncertainties inherent in the model's parameters and structure. + Available only if `.pre_transformation_std` is defined (e.g., using `optimizer="advi"`), + which reflects the standard deviation of the latent variables before transformation. + + 2) `.predict.covariance`: + Uncertainty for out-of-bag states originating from the compressed function representation + in the Gaussian Process. Specifically, this uncertainty corresponds to locations that are + not inducing points of the Gaussian Process and represents the covariance of the + conditional normal distribution. + _save_intermediate_ls_times : bool Determines whether the intermediate results obtained during the computation of `ls_time` are retained for debugging. When set to True, the results will be stored in `self.densities`, `self.predictors`, and `self.numeric_stages`. Defaults to False. - jit : bool Use jax just-in-time compilation for loss and its gradient during optimization. Defaults to False. + + check_rank : bool + Weather to check if landmarks allow sufficient complexity by checking the approximate + rank of the covariance matrix. This only applies to the non-Nyström gp_types. + If set to None the rank check is only performed if n_landmarks >= n_samples/10. + Defaults to None. 
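
A short usage sketch tying the parameters above together. The toy data below is purely illustrative, and `optimizer="advi"` is chosen here only because `.predict.mean_covariance` requires a posterior standard deviation; everything else follows the defaults documented above.

```python
import jax
import jax.numpy as jnp
import mellon

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 2))           # 100 cell states in 2 dimensions (toy data)
times = jnp.repeat(jnp.array([0.0, 1.0]), 50)  # two sampling time points, 50 cells each

est = mellon.TimeSensitiveDensityEstimator(
    gp_type="sparse_cholesky",        # explicit sparsification strategy
    n_landmarks=20,                   # inducing points for the sparse GP
    normalize_per_time_point=True,    # correct nn_distances for sampling bias
    predictor_with_uncertainty=True,  # enable .covariance/.mean_covariance/.uncertainty
    optimizer="advi",                 # provides pre_transformation_std for mean_covariance
)
est.fit(X, times)

predictor = est.predict                          # a PredictorTime instance
log_density = predictor(X, times)                # log density at the training cell states
std = jnp.sqrt(predictor.uncertainty(X, times))  # combined predictive standard deviation
```
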
""" def __init__( self, cov_func_curry=DEFAULT_COV_FUNC, - n_landmarks=DEFAULT_N_LANDMARKS, - rank=DEFAULT_RANK, - method=DEFAULT_METHOD, + n_landmarks=None, + rank=None, + gp_type=None, d_method=DEFAULT_D_METHOD, jitter=DEFAULT_JITTER, optimizer=DEFAULT_OPTIMIZER, @@ -203,19 +249,23 @@ def __init__( ls=None, ls_time=None, ls_factor=1, - ls_factor_times=1, + ls_time_factor=1, density_estimator_kwargs=dict(), cov_func=None, + Lp=None, L=None, initial_value=None, + predictor_with_uncertainty=False, _save_intermediate_ls_times=False, jit=DEFAULT_JIT, + check_rank=None, ): super().__init__( cov_func_curry=cov_func_curry, n_landmarks=n_landmarks, rank=rank, jitter=jitter, + gp_type=gp_type, optimizer=optimizer, n_iter=n_iter, init_learn_rate=init_learn_rate, @@ -226,9 +276,12 @@ def __init__( ls=ls, ls_factor=ls_factor, cov_func=cov_func, + Lp=Lp, L=L, initial_value=initial_value, + predictor_with_uncertainty=predictor_with_uncertainty, jit=jit, + check_rank=check_rank, ) if not isinstance(density_estimator_kwargs, dict): raise ValueError("density_estimator_kwargs needs to be a dictionary.") @@ -237,55 +290,56 @@ def __init__( d_method, "d_method", choices={"fractal", "embedding"} ) self.ls_time = _validate_positive_float(ls_time, "ls_time", optional=True) - self.ls_factor_times = _validate_positive_float( - ls_factor_times, "ls_factor_times" - ) + self.ls_time_factor = _validate_positive_float(ls_time_factor, "ls_time_factor") self._save_intermediate_ls_times = _save_intermediate_ls_times - self.normalize_per_time_point = _validate_bool( - normalize_per_time_point, "normalize_per_time_point" - ) + self.normalize_per_time_point = normalize_per_time_point self.transform = None self.loss_func = None self.opt_state = None self.losses = None self.pre_transformation = None + self.pre_transformation_std = None self.log_density_x = None self.log_density_func = None def __repr__(self): name = self.__class__.__name__ + landmarks = object_str(self.landmarks, ["landmarks", "dims"]) + Lp = object_str(self.Lp, ["landmarks", "landmarks"]) + L = object_str(self.L, ["cells", "ranks"]) + nn_distances = object_str(self.nn_distances, ["cells"]) + initial_value = object_str(self.initial_value, ["ranks"]) + normalize_per_time_point = object_str( + self.normalize_per_time_point, ["time points"] + ) + d = object_str(self.d, ["cells"]) string = ( f"{name}(" - f"cov_func_curry={self.cov_func_curry}, " - f"n_landmarks={self.n_landmarks}, " - f"rank={self.rank}, " - f"method='{self.method}', " - f"jitter={self.jitter}, " - f"optimizer='{self.optimizer}', " - f"n_iter={self.n_iter}, " - f"init_learn_rate={self.init_learn_rate}, " - f"landmarks={self.landmarks}, " + f"\n cov_func_curry={self.cov_func_curry}," + f"\n n_landmarks={self.n_landmarks}," + f"\n rank={self.rank}," + f"\n gp_type={self.gp_type}," + f"\n jitter={self.jitter}, " + f"\n optimizer={self.optimizer}," + f"\n landmarks={landmarks}," + f"\n nn_distances={nn_distances}," + f"\n normalize_per_time_point={normalize_per_time_point}," + f"\n d={d}," + f"\n mu={self.mu}," + f"\n ls={self.ls}," + f"\n ls_time={self.ls_time}," + f"\n ls_factor={self.ls_factor}," + f"\n ls_time_factor={self.ls_time_factor}," + f"\n density_estimator_kwargs={self.density_estimator_kwargs}," + f"\n cov_func={self.cov_func}," + f"\n Lp={Lp}," + f"\n L={L}," + f"\n initial_value={initial_value}," + f"\n predictor_with_uncertainty={self.predictor_with_uncertainty}," + f"\n jit={self.jit}," + f"\n check_rank={self.check_rank}," + "\n)" ) - if self.nn_distances is None: - string += 
"nn_distances=None, " - else: - string += "nn_distances=nn_distances, " - string += ( - f"d={self.d}, " - f"mu={self.mu}, " - f"ls={self.ls}, " - f"ls_time={self.ls_time}, " - f"cov_func={self.cov_func}, " - ) - if self.L is None: - string += "L=None, " - else: - string += "L=L, " - if self.initial_value is None: - string += "initial_value=None, " - else: - string += "initial_value=initial_value, " - string += f"jit={self.jit}" ")" return string def _compute_d(self): @@ -336,13 +390,28 @@ def _compute_loss_func(self): def _compute_nn_distances(self): x = self.x normalize_per_time_point = self.normalize_per_time_point + d = self.d logger.info("Computing nearest neighbor distances within time points.") nn_distances = compute_nn_distances_within_time_points( x, + d=d, normalize=normalize_per_time_point, ) return nn_distances + def _compute_ls(self): + nn_distances = self.nn_distances + normalized = self.normalize_per_time_point + if normalized is not False and normalized is not None: + logger.info( + "Computing non-normalized nn_distances for length scale heuristic." + ) + x = self.x + nn_distances = compute_nn_distances_within_time_points(x, normalize=False) + ls = compute_ls(nn_distances) + ls *= self.ls_factor + return ls + def _compute_ls_time(self): nn_distances = self.nn_distances x = self.x @@ -375,7 +444,7 @@ def _compute_ls_time(self): "Storing `self.densities`, `self.predictors`, and `self.numeric_stages`." ) ls, self.densities, self.predictors, self.numeric_stages = ls - ls *= self.ls_factor_times + ls *= self.ls_time_factor return ls def _compute_landmarks(self): @@ -414,20 +483,32 @@ def _set_log_density_func(self): x = self.x landmarks = self.landmarks pre_transformation = self.pre_transformation + pre_transformation_std = self.pre_transformation_std log_density_x = self.log_density_x mu = self.mu cov_func = self.cov_func + L = self.L + Lp = self.Lp jitter = self.jitter + with_uncertainty = self.predictor_with_uncertainty + normalize = self.normalize_per_time_point logger.info("Computing predictive function.") - log_density_func = compute_conditional_mean_times( + log_density_func = compute_conditional_times( x, landmarks, pre_transformation, + pre_transformation_std, log_density_x, mu, cov_func, + L, + Lp, + sigma=None, jitter=jitter, + y_is_mean=True, + with_uncertainty=with_uncertainty, ) + log_density_func.n_obs = compute_average_cell_count(x, normalize) self.log_density_func = log_density_func def prepare_inference(self, x, times=None): @@ -471,13 +552,18 @@ def prepare_inference(self, x, times=None): raise ValueError(message) x = self.set_x(x) - self._prepare_attribute("nn_distances") + self._prepare_attribute("n_landmarks") + self._prepare_attribute("rank") + self._prepare_attribute("gp_type") + self._validate_parameter() self._prepare_attribute("d") + self._prepare_attribute("nn_distances") self._prepare_attribute("mu") self._prepare_attribute("ls") self._prepare_attribute("ls_time") self._prepare_attribute("cov_func") self._prepare_attribute("landmarks") + self._prepare_attribute("Lp") self._prepare_attribute("L") self._prepare_attribute("initial_value") self._prepare_attribute("transform") @@ -566,23 +652,18 @@ def fit(self, x=None, times=None, build_predict=True): @property def predict(self): R""" - An instance of the :class:`mellon.Predictor` that predicts the log density at each point in x. + An instance of the :class:`mellon.base_predictor.PredictorTime` that predicts the + log density at each point in x and time point time. 
The instance contains a __call__ method which can be used to predict the log density. This instance also supports serialization features which allows for saving - and loading the predictor state. Refer to mellon.Predictor documentation for more details. - - Note that the last column of the input array `x` should contain the time information. - - Parameters - ---------- - x : array-like - The new data to predict, where the last column should contain the time information. + and loading the predictor state. Refer to :class:`mellon.base_predictor.PredictorTime` + documentation for more details. Returns ------- - log_density : array-like - The log density at each test point in `x`. + mellon.base_predictor.PredictorTime + A predictor instance that computes the log density at each new data point. Example ------- @@ -604,7 +685,7 @@ def fit_predict(self, x=None, times=None, build_predict=False): """ if x is not None: x = _validate_time_x(x, times) - if self.x is not None and self.x is not x: + if self.x is not None and x is not None and self.x is not x: message = "self.x has been set already, but is not equal to the argument x." error = ValueError(message) logger.error(error) diff --git a/mellon/util.py b/mellon/util.py index 4c714cf..a9e3f7a 100644 --- a/mellon/util.py +++ b/mellon/util.py @@ -2,7 +2,9 @@ import logging import functools import inspect +from typing import List from inspect import Parameter +from enum import Enum from jax.config import config as jaxconfig from jax.numpy import ( @@ -23,10 +25,11 @@ ndarray, array, isscalar, - exp, + where, ) from numpy import integer, floating from jax.numpy import sum as arraysum +from jax.numpy import diag as diagonal from jax.numpy.linalg import norm, lstsq, matrix_rank from jax.scipy.special import gammaln from jax import vmap, jit @@ -34,21 +37,12 @@ from .validation import _validate_array +logger = logging.getLogger("mellon") + DEFAULT_JITTER = 1e-6 DEFAULT_RANK_TOL = 5e-1 -def Exp(func): - """ - Function wrapper, making a function that returns the exponent of the wrapped function. - """ - - def new_func(x): - return exp(func(x)) - - return new_func - - def _None_to_str(v): if v is None: return "None" @@ -80,6 +74,8 @@ def make_serializable(x): return {"type": "slice", "data": dat} elif isinstance(x, dict): return {"type": "dict", "data": {k: make_serializable(v) for k, v in x.items()}} + elif isinstance(x, set): + return {"type": "set", "data": [make_serializable(v) for v in x]} else: return _None_to_str(x) @@ -113,6 +109,8 @@ def deserialize(serializable_x): return slice(*dat) elif data_type == "dict": return {k: deserialize(v) for k, v in serializable_x["data"].items()} + elif data_type == "set": + return {deserialize(v) for v in serializable_x["data"]} else: return _str_to_None(serializable_x) @@ -178,14 +176,16 @@ def make_multi_time_argument(func): Examples -------- - class MyClass: - @make_multi_time_argument - def method(self, x, time=None): - return x + time - - my_object = MyClass() - print(my_object.method(1, multi_time=np.array([1, 2, 3]))) - # Output: array([2, 3, 4]) + .. 
code-block:: python + + class MyClass: + @make_multi_time_argument + def method(self, x, time=None): + return x + time + + my_object = MyClass() + print(my_object.method(1, multi_time=np.array([1, 2, 3]))) + # Output: array([2, 3, 4]) """ sig = inspect.signature(func) new_params = list(sig.parameters.values()) + [ @@ -231,6 +231,41 @@ def stabilize(A, jitter=DEFAULT_JITTER): return A + eye(n) * jitter +def add_variance(K, M=None, jitter=DEFAULT_JITTER): + R""" + Computes :math:`K + MM^T` where the diagonal of :math:`MM^T` is + at least `jitter`. This function stabilizes :math:`K` for the + Cholesky decomposition if not already stable enough through adding :math:`MM^T`. + + Parameters + ---------- + K : array_like, shape (n, n) + A covariance matrix. + M : array_like, shape (n, p), optional + Left factor of additional variance. Default is 0. + jitter : float, optional + A small number to stabilize the covariance matrix. Defaults to 1e-6. + + Returns + ------- + combined_covariance : array_like, shape (n, n) + A combined covariance matrix that is more stably positive definite. + + Notes + ----- + If `M` is None, the function will add the jitter to the diagonal of `K` to + make it more stable. Otherwise, it will add :math:`MM^T` and correct the + diagonal based on the `jitter` parameter. + """ + if M is None: + K = stabilize(K, jitter=jitter) + else: + noise = M.dot(M.T) + diff = where(diagonal(noise) < jitter, jitter - diagonal(noise), 0) + K = K + noise + diagonal(diff) + return K + + def mle(nn_distances, d): R""" Nearest Neighbor distribution maximum likelihood estimate for log density @@ -346,6 +381,13 @@ def local_dimensionality(x, k=30, x_query=None, neighbor_idx=None): This function computes the local fractal dimension of a dataset at query points. It uses nearest neighbors and fits a line in log-log space to estimate the fractal dimension. """ + if k > x.shape[0]: + logger.warning( + f"Number of nearest neighbors (k={k}) is " + f"greater than the number of samples ({x.shape[0]}). " + "Setting k to the number of samples." + ) + k = x.shape[0] if neighbor_idx is None: if x_query is None: x_query = x @@ -379,13 +421,11 @@ class Log(object): def __new__(cls): """Return the singelton Logger.""" if not hasattr(cls, "logger"): - logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) cls.handler = logging.StreamHandler(sys.stdout) formatter = logging.Formatter("[%(asctime)s] [%(levelname)-8s] %(message)s") cls.handler.setFormatter(formatter) logger.addHandler(cls.handler) - logger.propagate = False cls.logger = logger return cls.logger @@ -417,3 +457,123 @@ def set_jax_config(enable_x64=True, platform_name="cpu"): """ jaxconfig.update("jax_enable_x64", enable_x64) jaxconfig.update("jax_platform_name", platform_name) + + +class GaussianProcessType(Enum): + """ + Defines types of Gaussian Process (GP) computations for various estimators within the mellon library: + :class:`mellon.model.DensityEstimator`, :class:`mellon.model.FunctionEstimator`, + :class:`mellon.model.DimensionalityEstimator`, :class:`mellon.model.TimeSensitiveDensityEstimator`. + + This enum can be passed through the `gp_type` attribute to the mentioned estimators. + If a string representing one of these values is passed alternatively, the + :func:`from_string` method is called to convert it to a `GaussianProcessType`. + + Options are 'full', 'full_nystroem', 'sparse_cholesky', 'sparse_nystroem'. 
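
For reference, a minimal sketch of the conversion behavior described above; the input strings are illustrative, and the partial string "sparse" is used only to trigger the closest-match warning.

```python
from mellon.util import GaussianProcessType

# Exact matches return the corresponding enum member.
assert GaussianProcessType.from_string("full_nystroem") is GaussianProcessType.FULL_NYSTROEM

# A partial match logs a warning and returns the closest option.
partial = GaussianProcessType.from_string("sparse")
assert isinstance(partial, GaussianProcessType)

# None is accepted only when explicitly marked optional.
assert GaussianProcessType.from_string(None, optional=True) is None

# Unknown strings raise a ValueError.
try:
    GaussianProcessType.from_string("unknown_type")
except ValueError as err:
    print(err)
```
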
+ """ + + FULL = "full" + FULL_NYSTROEM = "full_nystroem" + SPARSE_CHOLESKY = "sparse_cholesky" + SPARSE_NYSTROEM = "sparse_nystroem" + + @staticmethod + def from_string(s: str, optional: bool = False): + """ + Converts a string to a GaussianProcessType object or raises an error. + + Parameters + ---------- + s : str + The type of Gaussian Process (GP). Options are: + - 'full': None-sparse GP + - 'full_nystroem': Sparse GP with Nyström rank reduction + - 'sparse_cholesky': Sparse GP using landmarks/inducing points + - 'sparse_nystroem': Sparse GP along with an improved Nyström rank reduction + optional : bool, optional + Specifies if the input is optional. Returns None if True and input is None. + + Returns + ------- + GaussianProcessType + Corresponding Gaussian Process type. + + Raises + ------ + ValueError + If the input does not correspond to any known Gaussian Process type. + """ + + if s is None: + if optional: + return None + else: + logger.error("Gaussian process type must be specified but is None.") + raise ValueError("Gaussian process type must be specified but is None.") + + if isinstance(s, GaussianProcessType): + return s + + normalized_input = s.lower().replace(" ", "_") + + # Try to match the exact Enum value + for gp_type in GaussianProcessType: + if gp_type.value == normalized_input: + logger.info(f"Gaussian Process type: {gp_type.value}") + return gp_type + + # If no exact match, try partial matching by finding the closest match + for gp_type in GaussianProcessType: + if normalized_input in gp_type.value: + logger.warning( + f"Partial match found for Gaussian Process type: {gp_type.value}. Input was: {s}" + ) + return gp_type + + message = f"Unknown Gaussian Process type: {s}" + logger.error(message) + raise ValueError(message) + + +def object_str(obj: object, dim_names: List[str] = None) -> str: + """ + Generate a concise string representation of metadata for array-like objects. + + Parameters + ---------- + obj : object + Object for which to generate metadata string. + + dim_names : list of str, optional + Names for dimensions, used for array-like objects. + + Returns + ------- + str + Metadata string. + + Examples + -------- + >>> object_metadata_str(np.array([[1, 2], [3, 4]]), dim_names=['row', 'col']) + '' + + >>> object_metadata_str(np.array([1, 2, 3]), dim_names=['element']) + '' + + >>> object_metadata_str("hello") + 'hello' + """ + if hasattr(obj, "shape") and hasattr(obj, "dtype"): + dims = obj.shape + if dim_names: + dim_strs = [f"{dim:,} {name}" for dim, name in zip(dims, dim_names)] + else: + dim_strs = [f"{dim:,}" for dim in dims] + + for i in range(len(dim_strs), len(dims)): + dim_strs.append(f"{dims[i]} dimension {i + 1}") + + dim_str = " x ".join(dim_strs) + return f"" + else: + return str(obj) diff --git a/mellon/validation.py b/mellon/validation.py index d579974..d01b07e 100644 --- a/mellon/validation.py +++ b/mellon/validation.py @@ -1,8 +1,11 @@ from collections.abc import Iterable +import logging from jax.numpy import asarray, concatenate, isscalar, full, ndarray from jax.errors import ConcretizationTypeError +logger = logging.getLogger(__name__) + def _validate_time_x(x, times=None, n_features=None, cast_scalar=False): """ @@ -86,6 +89,29 @@ def _validate_time_x(x, times=None, n_features=None, cast_scalar=False): def _validate_float_or_int(value, param_name, optional=False): + """ + Validates whether a given value is a float or an integer. + + Parameters + ---------- + value : float, int, or string, or None + The value to be validated. 
It should be a float, integer, or convertible to a float. + param_name : str + The name of the parameter to be used in the error message. + optional : bool, optional + Whether the value is optional. If optional and value is None, returns None. Default is False. + + Returns + ------- + float or int + The validated value as float or int. + + Raises + ------ + ValueError + If the value is not float, int, convertible to float, and not None when not optional. + """ + if value is None and optional: return None @@ -101,6 +127,29 @@ def _validate_float_or_int(value, param_name, optional=False): def _validate_positive_float(value, param_name, optional=False): + """ + Validates whether a given value is a positive float. + + Parameters + ---------- + value : float, int, or string, or None + The value to be validated. It should be a positive float or convertible to a positive float. + param_name : str + The name of the parameter to be used in the error message. + optional : bool, optional + Whether the value is optional. If optional and value is None, returns None. Default is False. + + Returns + ------- + float + The validated value as a positive float. + + Raises + ------ + ValueError + If the value is not a positive float, not convertible to a positive float, and not None when not optional. + """ + if value is None and optional: return None @@ -164,6 +213,29 @@ def _validate_float(value, param_name, optional=False): def _validate_positive_int(value, param_name, optional=False): + """ + Validates whether a given value is a positive integer. + + Parameters + ---------- + value : int or None + The value to be validated. It should be a positive integer. + param_name : str + The name of the parameter to be used in the error message. + optional : bool, optional + Whether the value is optional. If optional and value is None, returns None. Default is False. + + Returns + ------- + int or None + The validated value as a positive integer, or None if the value is optional and None. + + Raises + ------ + ValueError + If the value is not a positive integer and not None when not optional. + """ + if optional and value is None: return None if not isinstance(value, int) or value < 0: @@ -233,7 +305,39 @@ def _validate_array(iterable, name, optional=False, ndim=None): return array -def _validate_bool(value, name): +def _validate_bool(value, name, optional=False): + """ + Validates whether a given value is a boolean. + + Parameters + ---------- + value : any + The value to be validated. + + name : str + The name of the parameter to be used in the error message. + + optional : bool, optional + If True, 'value' can be None and the function will return None in this case. + If False and 'value' is None, a TypeError is raised. Defaults to False. + + Returns + ------- + bool + The validated value as a boolean. + + Raises + ------ + TypeError + If the value is not of type bool. + """ + + if value is None: + if optional: + return None + else: + raise TypeError(f"'{name}' can't be None.") + if not isinstance(value, bool): raise TypeError(f"{name} should be of type bool, got {type(value)} instead.") @@ -241,6 +345,31 @@ def _validate_bool(value, name): def _validate_string(value, name, choices=None): + """ + Validates whether a given value is a string and optionally whether it is in a set of choices. + + Parameters + ---------- + value : any + The value to be validated. + name : str + The name of the parameter to be used in the error message. + choices : list of str, optional + A list of valid string choices. 
If provided, the value must be one of these choices. + + Returns + ------- + str + The validated value as a string. + + Raises + ------ + TypeError + If the value is not of type str. + ValueError + If the value is not one of the choices (if provided). + """ + if not isinstance(value, str): raise TypeError(f"{name} should be of type str, got {type(value)} instead.") @@ -250,52 +379,55 @@ def _validate_string(value, name, choices=None): return value -def _validate_float_or_iterable_numerical(value, name, optional=False): +def _validate_float_or_iterable_numerical(value, name, optional=False, positive=False): + """ + Validates whether a given value is a float, integer, or iterable of numerical values, + with an option to check for non-negativity. + + Parameters + ---------- + value : int, float, Iterable or None + The value to be validated. + name : str + The name of the parameter to be used in the error message. + optional : bool, optional + Whether the value is optional. If optional and value is None, returns None. Default is False. + positive : bool, optional + Whether to validate that the value is non-negative. Default is False. + + Returns + ------- + float or ndarray + The validated value as a float or a numeric array. + + Raises + ------ + TypeError + If the value is not of type int, float or iterable. + ValueError + If the value could not be converted to a numeric array (if iterable) or if the value is negative (if positive is True). + """ + if value is None and optional: return None if isinstance(value, (int, float)): - return float(value) + value = float(value) + if positive and value < 0: + raise ValueError(f"{name} should be a non-negative number or array") + return value - if isinstance(value, Iterable): - try: - return asarray(value, dtype=float) - except Exception: - raise ValueError(f"Could not convert {name} to a numeric array.") + if isinstance(value, Iterable) and not isinstance(value, str): + result = asarray(value, dtype=float) + if positive and (result < 0).any(): + raise ValueError(f"All elements in {name} should be non-negative") + return result raise TypeError( f"{name} should be of type int, float or iterable, got {type(value)} instead." ) -def _validate_cov_func_curry(cov_func_curry, cov_func, param_name): - if cov_func_curry is None and cov_func is None: - raise ValueError( - "At least one of 'cov_func_curry' and 'cov_func' must not be None" - ) - - from .base_cov import Covariance - - if cov_func_curry is not None: - if not isinstance(cov_func_curry, type) or not issubclass( - cov_func_curry, Covariance - ): - raise ValueError(f"'{param_name}' must be a subclass of mellon.Covariance") - return cov_func_curry - - -def _validate_cov_func(cov_func, param_name, optional=False): - if cov_func is None and optional: - return None - from .base_cov import Covariance - - if not isinstance(cov_func, Covariance): - raise ValueError( - f"'{param_name}' must be an instance of a subclass of mellon.Covariance" - ) - return cov_func - - def _validate_1d(x): """ Validates that `x` can be cast to a JAX array with exactly 1 dimension and float data type. 
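
Since these helpers now accept `optional`, `choices`, and `positive` switches, a brief sketch of the intended behavior may be useful; the functions are private, so importing them directly is shown purely for illustration.

```python
from mellon.validation import (
    _validate_bool,
    _validate_string,
    _validate_float_or_iterable_numerical,
)

# `optional=True` lets None pass through instead of raising a TypeError.
assert _validate_bool(None, "normalize", optional=True) is None

# `choices` restricts the accepted strings.
_validate_string("embedding", "d_method", choices={"fractal", "embedding"})

# Iterables are cast to a float array; `positive=True` rejects negative entries.
_validate_float_or_iterable_numerical([1, 2, 3], "counts", positive=True)
try:
    _validate_float_or_iterable_numerical([-1.0, 2.0], "counts", positive=True)
except ValueError as err:
    print(err)  # e.g. "All elements in counts should be non-negative"
```
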
diff --git a/notebooks/time-series_tutorial.ipynb b/notebooks/time-series_tutorial.ipynb index 48ea509..df7e280 100644 --- a/notebooks/time-series_tutorial.ipynb +++ b/notebooks/time-series_tutorial.ipynb @@ -159,14 +159,6 @@ "ad" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "391b355d-9f3c-40bc-b02c-5267402f028e", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "id": "37de312b", @@ -378,7 +370,9 @@ "source": [ "## Step 4: Density Evaluation and Plotting\n", "\n", - "The cell-state density can be evaluated for any time point. For instance, we can evaluate the\n", + "The `density_predictor` is a subclass of the\n", + "[mellon.base_predictor.PredictorTime](https://mellon.readthedocs.io/en/latest/predictor.html#mellon.base_predictor.PredictorTime).\n", + "Thus the cell-state density can be evaluated for any time point. For instance, we can evaluate the\n", "density for all cell states at stage E7.10 wich is notably not covered by a sample." ] }, @@ -731,7 +725,11 @@ "\n", "## Step 7: Saving the Predictor\n", "\n", - "The density predictor can be serialized and saved for future use. It can be stored in the AnnData `uns` attribute, and the AnnData can be saved with the `write` method. The predictor can be reconstructed using the `from_dict` method of `mellon.Predictor`." + "The density predictor can be serialized and saved for future use.\n", + "It can be stored in the AnnData `uns` attribute, and the AnnData can be saved with the `write` method.\n", + "The predictor can be reconstructed using the\n", + "[from_dict](https://mellon.readthedocs.io/en/latest/predictor.html#mellon.Predictor.from_dict)\n", + "method of [mellon.Predictor](https://mellon.readthedocs.io/en/latest/predictor.html#mellon.Predictor)." 
] }, { @@ -820,9 +818,9 @@ ], "metadata": { "kernelspec": { - "display_name": "mellon-test", + "display_name": "mellon_v2", "language": "python", - "name": "mellon-test" + "name": "mellon_v2" }, "language_info": { "codemirror_mode": { @@ -834,11 +832,109 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.10.10" }, "widgets": { "application/vnd.jupyter.widget-state+json": { - "state": {}, + "state": { + "009d4a832bc547eda98b7134b8bfa189": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "0956b753c5db4344907cbdc9934269bd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "1b675b44738948378acc3e81a5a190e1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "children": [ + "IPY_MODEL_339b3c18a5a249ba93043baf40ec7e4f", + "IPY_MODEL_22e8e9adc7de4f88bdaae6afa680dc7c", + "IPY_MODEL_6ccb85f98c8e4378b7a9dc92691a001c" + ], + "layout": "IPY_MODEL_7b4c74dd7d054a5c821a9a25958b1258" + } + }, + "22e8e9adc7de4f88bdaae6afa680dc7c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "bar_style": "success", + "layout": "IPY_MODEL_009d4a832bc547eda98b7134b8bfa189", + "max": 18684354899, + "style": "IPY_MODEL_fe530b4f59bb414796f46b1e5716f344", + "value": 18684354899 + } + }, + "339b3c18a5a249ba93043baf40ec7e4f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "layout": "IPY_MODEL_ee6bc90a82fb468c9cb840306406f599", + "style": "IPY_MODEL_0956b753c5db4344907cbdc9934269bd", + "value": "100%" + } + }, + "604ee0f3c5c542c2a39e2d2ba52aa852": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "6ccb85f98c8e4378b7a9dc92691a001c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "layout": "IPY_MODEL_e0bae8abaf084635bc297e90a0330ad9", + "style": "IPY_MODEL_604ee0f3c5c542c2a39e2d2ba52aa852", + "value": " 17.4G/17.4G [16:58<00:00, 20.3MB/s]" + } + }, + "7b4c74dd7d054a5c821a9a25958b1258": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "e0bae8abaf084635bc297e90a0330ad9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "ee6bc90a82fb468c9cb840306406f599": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": {} + }, + "fe530b4f59bb414796f46b1e5716f344": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "description_width": "" + } + } + }, "version_major": 2, "version_minor": 0 } diff --git a/notebooks/trajectory-trends_tutorial.ipynb b/notebooks/trajectory-trends_tutorial.ipynb index 4fb0593..c2e8a48 100644 --- a/notebooks/trajectory-trends_tutorial.ipynb +++ b/notebooks/trajectory-trends_tutorial.ipynb @@ -293,7 +293,9 @@ "id": 
"8129bf9a", "metadata": {}, "source": [ - "Palantir also offers a wrapper for the `mellon.FunctionEstimator` to run the trend computation on all branches simultaneously and store them in `ad.varm[\"gene_trends_\"]`:" + "Palantir also offers a wrapper for the\n", + "[`mellon.FunctionEstimator`](https://mellon.readthedocs.io/en/uncertainty/model.html#mellon.model.FunctionEstimator)\n", + "to run the trend computation on all branches simultaneously and store them in `ad.varm[\"gene_trends_\"]`:" ] }, { diff --git a/setup.py b/setup.py index 13b8bbb..51dc964 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def get_version(rel_path): install_requires=[ "jax", "jaxopt", - "numpy<1.25.0", # Specific version required due to compatibility with jax<0.4.15 + "numpy<1.25.0", # Specific version required due to compatibility with jax<0.4.15 "scikit-learn", ], classifiers=[ diff --git a/tests/density_estimator.py b/tests/density_estimator.py index 6010d8e..a085eaa 100644 --- a/tests/density_estimator.py +++ b/tests/density_estimator.py @@ -31,12 +31,18 @@ def test_density_estimator_properties(common_setup): X, _, _, relative_err, est, _ = common_setup n, d = X.shape + len_str = len(str(mellon.DensityEstimator())) + assert len_str > 0, "The model should have a string representation." + pred_log_dens = est.predict(X) assert relative_err(pred_log_dens) < 1e-5, ( "The predicive function should be consistent with the density on " "the training samples." ) + pred_str = len(str(est.predict)) + assert pred_str > 0, "The predictor should have a string representation." + grads = est.predict.gradient(X) assert ( grads.shape == X.shape @@ -53,7 +59,9 @@ def test_density_estimator_properties(common_setup): assert sng.shape == (n,), "There should be one sign for each hessian determinan." assert ld.shape == (n,), "There should be one value for each hessian determinan." - assert len(str(est)) > 0, "The model should have a string representation." + assert ( + len(str(est)) > len_str + ), "The model should have a longer string representation after fitting." def test_density_estimator_fractal_dimension(common_setup): @@ -81,19 +89,17 @@ def test_density_estimator_optimizers(common_setup): @pytest.mark.parametrize( - "rank, method, n_landmarks, err_limit", + "rank, n_landmarks, err_limit", [ - (1.0, "percent", 0, 1e-1), - (1.0, "percent", 10, 2e-1), - (0.99, "percent", 80, 2e-1), + (1.0, 0, 1e-1), + (1.0, 10, 2e-1), + (0.99, 80, 2e-1), ], ) -def test_density_estimator_approximations( - common_setup, rank, method, n_landmarks, err_limit -): +def test_density_estimator_approximations(common_setup, rank, n_landmarks, err_limit): X, _, _, relative_err, _, _ = common_setup - est = mellon.DensityEstimator(rank=rank, method=method, n_landmarks=n_landmarks) + est = mellon.DensityEstimator(rank=rank, n_landmarks=n_landmarks) est.fit(X) dens_appr = est.predict(X) assert ( @@ -117,17 +123,123 @@ def test_density_estimator_serialization(common_setup, rank, n_landmarks, compre est = mellon.DensityEstimator(rank=rank, n_landmarks=n_landmarks) est.fit(X) dens_appr = est.predict(X) + norm_dens_appr = est.predict(X, normalize=True) + is_close = jnp.all(jnp.isclose(dens_appr, norm_dens_appr)) + assert not is_close, "The normalized and non-normalized predictions should differ." # Test serialization + json_string = est.predict.to_json() + assert isinstance( + json_string, str + ), "Json string should be returned if no filename is given." 
est.predict.to_json(test_file, compress=compress) + est.predict.to_json(str(test_file), compress=compress) logger.info(f"Serialized the predictor and saved it to {test_file}.") predictor = mellon.Predictor.from_json(test_file, compress=compress) logger.info("Deserialized the predictor from the JSON file.") reprod = predictor(X) + norm_reprod = predictor(X, normalize=True) + logger.info("Made a prediction with the deserialized predictor.") + is_close = jnp.all(jnp.isclose(dens_appr, reprod)) + norm_is_close = jnp.all(jnp.isclose(norm_dens_appr, norm_reprod)) + assert_msg = "Serialized + deserialized predictor should produce the same results." + assert is_close and norm_is_close, assert_msg + logger.info( + "Assertion passed: the deserialized predictor produced the expected results." + ) + logger.info("Serializing deserialized predictor again.") + predictor.to_json(test_file, compress=compress) + # test backwards compatibility + edict = predictor.to_dict() + edict["metadata"]["module_version"] = "1.3.1" + edict["data"].pop("n_obs") + edict["data"].pop("_state_variables") + mellon.Predictor.from_dict(edict) + reprod = predictor(X, normalize=False) + is_close = jnp.all(jnp.isclose(dens_appr, reprod)) + assert ( + is_close + ), "Deserialized predictor of mellon 1.3.1 should produce the same results." + + +def test_density_estimator_without_uncertainty(common_setup): + X, _, _, _, est, _ = common_setup + + with pytest.raises(ValueError): + est.predict.covariance(X) + with pytest.raises(ValueError): + est.predict.mean_covariance(X) + with pytest.raises(ValueError): + est.predict.uncertainty(X) + + +@pytest.mark.parametrize( + "rank, n_landmarks, compress", + [ + (1.0, 0, None), + (0.99, 0, None), + (1.0, 10, None), + (0.99, 80, None), + ], +) +def test_density_estimator_serialization_with_uncertainty( + common_setup, rank, n_landmarks, compress +): + X, test_file, logger, _, _, _ = common_setup + n = X.shape[0] + + est = mellon.DensityEstimator( + rank=rank, + n_landmarks=n_landmarks, + optimizer="advi", + predictor_with_uncertainty=True, + ) + est.fit(X) + dens_appr = est.predict(X) + covariance = est.predict.covariance(X) + assert covariance.shape == ( + n, + ), "The diagonal of the covariance matrix should be reported." + mean_covariance = est.predict.mean_covariance(X) + assert mean_covariance.shape == ( + n, + ), "The diagonal of the mean covariance should be reported." + uncertainty_pred = est.predict.uncertainty(X) + assert uncertainty_pred.shape == (n,), "One value per sample should be reported." + + full_covariance = est.predict.covariance(X, diag=False) + assert full_covariance.shape == ( + n, + n, + ), "The full covariance matrix should be repoorted." + full_mean_covariance = est.predict.mean_covariance(X, diag=False) + assert full_mean_covariance.shape == ( + n, + n, + ), "The full mean covariance should be repoorted." + full_uncertainty_pred = est.predict.uncertainty(X, diag=False) + assert full_uncertainty_pred.shape == ( + n, + n, + ), "The full covariance should be reported." + + # Test serialization + est.predict.to_json(test_file, compress=compress) + logger.info( + f"Serialized the predictor with uncertainty and saved it to {test_file}." 
+ ) + predictor = mellon.Predictor.from_json(test_file, compress=compress) + logger.info("Deserialized the predictor with uncertainty from the JSON file.") + reprod = predictor(X) logger.info("Made a prediction with the deserialized predictor.") is_close = jnp.all(jnp.isclose(dens_appr, reprod)) assert_msg = "Serialized + deserialized predictor should produce the same results." assert is_close, assert_msg + reprod_uncertainty = predictor.uncertainty(X) + logger.info("Made a uncertainty prediction with the deserialized predictor.") + is_close = jnp.all(jnp.isclose(uncertainty_pred, reprod_uncertainty)) + assert_msg = "Serialized + deserialized predictor should produce the same uncertainty results." + assert is_close, assert_msg logger.info( "Assertion passed: the deserialized predictor produced the expected results." ) @@ -161,8 +273,53 @@ def test_density_estimator_single_dimension(common_setup): d1_pred = est.fit_predict(X[:, 0]) assert d1_pred.shape == (n,), "There should be one result per sample." - est = mellon.DensityEstimator(rank=1.0, method="percent", n_landmarks=0) + est = mellon.DensityEstimator(rank=1.0, n_landmarks=0) d1_pred_full = est.fit_predict(X[:, 0]) assert ( jnp.std(d1_pred - d1_pred_full) < 1e-2 ), "The scalar state function estimations be consistent under approximation." + + +def test_density_estimator_errors(common_setup): + X, test_file, _, _, _, _ = common_setup + lX = jnp.concatenate( + [ + X, + ] + * 26, + axis=1, + ) + est = mellon.DensityEstimator() + + with pytest.raises(ValueError): + est.fit_predict() + with pytest.raises(ValueError): + est.fit(lX) + with pytest.raises(ValueError): + est.fit(None) + est.set_x(X) + with pytest.raises(ValueError): + est.prepare_inference(lX) + loss_func, initial_value = est.prepare_inference(None) + est.run_inference(loss_func, initial_value, "advi") + est.process_inference(est.pre_transformation) + with pytest.raises(ValueError): + est.fit_predict(lX) + predictor = est.predict + with pytest.raises(ValueError): + predictor(X[:, :-1]) + with pytest.raises(ValueError): + predictor.covariance(X[:, :-1]) + with pytest.raises(ValueError): + predictor.mean_covariance(X[:, :-1]) + with pytest.raises(ValueError): + predictor.uncertainty(X[:, :-1]) + with pytest.raises(ValueError): + predictor.to_json(test_file, compress="bad_type") + est.fit_predict() + est.predict.n_obs = None + with pytest.raises(ValueError): + est.predict(X, normalize=True) + est = mellon.DensityEstimator(predictor_with_uncertainty=True) + with pytest.raises(ValueError): + est.fit(X) diff --git a/tests/dimensionality_estimator.py b/tests/dimensionality_estimator.py index f2ad788..4a9f000 100644 --- a/tests/dimensionality_estimator.py +++ b/tests/dimensionality_estimator.py @@ -5,10 +5,11 @@ @pytest.fixture -def common_setup_dim_estimator(): +def common_setup_dim_estimator(tmp_path): n = 100 d = 2 seed = 535 + test_file = tmp_path / "predictor.json" key = jax.random.PRNGKey(seed) L = jax.random.uniform(key, (d, d)) cov = L.T.dot(L) @@ -23,11 +24,11 @@ def relative_err(dim): diff_dim = jnp.std(local_dim - dim) / dim_std return diff_dim - return X, local_dim, relative_err, est, dim_std + return X, test_file, local_dim, relative_err, est, dim_std def test_dimensionality_estimator_properties(common_setup_dim_estimator): - X, local_dim, relative_err, est, _ = common_setup_dim_estimator + X, _, local_dim, relative_err, est, _ = common_setup_dim_estimator n, d = X.shape pred = est.predict(X) @@ -56,8 +57,60 @@ def 
test_dimensionality_estimator_properties(common_setup_dim_estimator): assert ld.shape == (n,), "There should be one value for each sample." +@pytest.mark.parametrize( + "rank, n_landmarks, compress", + [ + (1.0, 0, None), + (0.99, 0, None), + (1.0, 10, None), + (0.99, 80, None), + ], +) +def test_dimensionality_estimator_serialization_with_uncertainty( + common_setup_dim_estimator, rank, n_landmarks, compress +): + X, test_file, _, _, _, _ = common_setup_dim_estimator + n = X.shape[0] + + est = mellon.DimensionalityEstimator( + rank=rank, + n_landmarks=n_landmarks, + optimizer="advi", + predictor_with_uncertainty=True, + ) + est.fit(X) + dens_appr = est.predict(X) + log_dens_appr = est.predict(X, logscale=True) + is_close = jnp.all(jnp.isclose(dens_appr, jnp.exp(log_dens_appr))) + assert ( + is_close + ), "The exp of the log scale prediction should mix the original prediction." + covariance = est.predict.covariance(X) + assert covariance.shape == ( + n, + ), "The diagonal of the covariance matrix should be repoorted." + mean_covariance = est.predict.mean_covariance(X) + assert mean_covariance.shape == ( + n, + ), "The diagonal of the mean covariance should be repoorted." + uncertainty_pred = est.predict.uncertainty(X) + assert uncertainty_pred.shape == (n,), "One value per sample should be reported." + + # Test serialization + est.predict.to_json(test_file, compress=compress) + predictor = mellon.Predictor.from_json(test_file, compress=compress) + reprod = predictor(X) + is_close = jnp.all(jnp.isclose(dens_appr, reprod)) + assert_msg = "Serialized + deserialized predictor should produce the same results." + assert is_close, assert_msg + reprod_uncertainty = predictor.uncertainty(X) + is_close = jnp.all(jnp.isclose(uncertainty_pred, reprod_uncertainty)) + assert_msg = "Serialized + deserialized predictor should produce the same uncertainty results." + assert is_close, assert_msg + + def test_dimensionality_estimator_optimizer(common_setup_dim_estimator): - X, local_dim, relative_err, _, _ = common_setup_dim_estimator + X, _, local_dim, relative_err, _, _ = common_setup_dim_estimator adam_est = mellon.DimensionalityEstimator(optimizer="adam") adam_dim = adam_est.fit_predict(X) @@ -67,24 +120,48 @@ def test_dimensionality_estimator_optimizer(common_setup_dim_estimator): @pytest.mark.parametrize( - "rank, method, n_landmarks, err_limit", + "rank, n_landmarks, err_limit", [ - (1.0, "percent", 100, 1e0), - (1.0, "percent", 10, 1e0), - (0.99, "percent", 80, 1e0), - (50, "auto", 80, 1e0), + (1.0, 100, 1e0), + (1.0, 10, 1e0), + (0.99, 80, 1e0), + (50, 80, 1e0), ], ) def test_dimensionality_estimator_approximations( - common_setup_dim_estimator, rank, method, n_landmarks, err_limit + common_setup_dim_estimator, rank, n_landmarks, err_limit ): - X, local_dim, relative_err, _, _ = common_setup_dim_estimator + X, _, local_dim, relative_err, _, _ = common_setup_dim_estimator - est = mellon.DimensionalityEstimator( - rank=rank, method=method, n_landmarks=n_landmarks - ) + est = mellon.DimensionalityEstimator(rank=rank, n_landmarks=n_landmarks) est.fit(X) dim_appr = est.predict(X) assert ( relative_err(dim_appr) < err_limit ), "The approximation should be close to the default." 
+ + +def test_dimensionality_estimator_errors(common_setup_dim_estimator): + X, _, _, _, _, _ = common_setup_dim_estimator + lX = jnp.concatenate( + [ + X, + ] + * 26, + axis=1, + ) + est = mellon.DimensionalityEstimator() + + with pytest.raises(ValueError): + est.fit_predict() + with pytest.raises(ValueError): + est.fit(None) + est.set_x(X) + with pytest.raises(ValueError): + est.prepare_inference(lX) + loss_func, initial_value = est.prepare_inference(None) + est.run_inference(loss_func, initial_value, "advi") + est.process_inference(est.pre_transformation) + with pytest.raises(ValueError): + est.fit_predict(lX) + est.fit_predict() diff --git a/tests/function_estimator.py b/tests/function_estimator.py index 1f25d52..ed4f45e 100644 --- a/tests/function_estimator.py +++ b/tests/function_estimator.py @@ -16,7 +16,7 @@ def function_estimator_setup(): noise = 1e-2 * jnp.sum(jnp.sin(X * 1e16), axis=1) noiseless_y = jnp.sum(jnp.sin(X / 2), axis=1) y = noiseless_y + noise - Y = jnp.stack([y, noiseless_y]) + Y = jnp.stack([y, noiseless_y], axis=1) return X, y, Y, noiseless_y @@ -24,7 +24,15 @@ def function_estimator_setup(): def test_function_estimator_prediction(function_estimator_setup): X, y, _, noiseless_y = function_estimator_setup n = X.shape[0] + + with pytest.raises(ValueError): + mellon.FunctionEstimator(gp_type="sparse_nystroem") + est = mellon.FunctionEstimator(sigma=1e-3) + + with pytest.raises(TypeError): + est.fit_predict() + pred = est.fit_predict(X, y) assert pred.shape == (n,), "There should be a predicted value for each sample." @@ -34,17 +42,38 @@ def test_function_estimator_prediction(function_estimator_setup): err = jnp.std(noiseless_y - pred) assert err < 1e-2, "The prediction should be close to the true value." + pred_self = est(X, y) + assert jnp.all( + jnp.isclose(pred, pred_self) + ), "__call__() shoud return the same as predict()" + + est.compute_conditional(y=y) + est.compute_conditional(x=y, y=y) + + with pytest.raises(ValueError): + est.compute_conditional(X) + + with pytest.raises(ValueError): + est.fit(X, y[:3]) + + with pytest.raises(ValueError): + est.fit_predict(X[:, :, None], y) + + with pytest.raises(ValueError): + est.fit_predict(X[:3, :], y) + def test_function_estimator_multi_fit_predict(function_estimator_setup): X, y, Y, _ = function_estimator_setup n = X.shape[0] est = mellon.FunctionEstimator(sigma=1e-3) - m_pred = est.multi_fit_predict(X, Y, X) + m_pred = est.fit_predict(X, Y, X) assert m_pred.shape == ( - 2, n, + 2, ), "There should be a value for each sample and location." 
+ est.multi_fit_predict(X, Y.T, X) @pytest.mark.parametrize("n_landmarks, error_limit", [(0, 1e-4), (10, 1e-1)]) diff --git a/tests/parameters.py b/tests/parameters.py index fff57cb..1c7e4d4 100644 --- a/tests/parameters.py +++ b/tests/parameters.py @@ -1,5 +1,303 @@ -import mellon +import pytest +from enum import Enum import jax.numpy as jnp +from jax import random +import logging +import mellon +from mellon.parameters import ( + compute_n_landmarks, + compute_rank, + compute_nn_distances, + compute_gp_type, + compute_landmarks_rescale_time, + compute_nn_distances_within_time_points, + compute_d_factal, + compute_Lp, + compute_L, +) +from mellon.util import GaussianProcessType + + +def test_compute_landmarks_rescale_time(): + x = jnp.array([[1, 2], [3, 4], [3, 5]]) + + lm = compute_landmarks_rescale_time(x, 1, 1, n_landmarks=0) + assert lm is None, "Non should be returned if n_landmarks=0" + + # Testing input validation by passing negative length scales + with pytest.raises(ValueError): + compute_landmarks_rescale_time(x, -1, 1) + + with pytest.raises(ValueError): + compute_landmarks_rescale_time(x, 1, -1) + + lm = compute_landmarks_rescale_time(x, 1, 2, n_landmarks=2) + assert lm.shape == (2, 2), "`n_landmarks` landmars should be retuned." + + +def test_compute_nn_distances_within_time_points(): + # Test basic functionality without normalization + x = jnp.array([[1, 2, 0], [3, 4, 0], [5, 6, 1], [7, 8, 1]]) + result = compute_nn_distances_within_time_points(x) + assert result.shape == (4,) + + # Test behavior with insufficient samples at a given time point + x_single_sample = jnp.array([[1, 2, 0], [3, 4, 1], [5, 6, 2]]) + with pytest.raises(ValueError): + compute_nn_distances_within_time_points(x_single_sample) + + # Test functionality with the times array passed separately + x_without_times = jnp.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + times = jnp.array([0, 0, 1, 1]) + result_with_times = compute_nn_distances_within_time_points( + x_without_times, times=times + ) + assert jnp.all(result_with_times == result) + + +def test_compute_d_factal(caplog): + # Create a random key for jax.random + key = random.PRNGKey(0) + + # Create a random array using jax.random + x_2d = random.normal(key, shape=(100, 10)) + result_2d = compute_d_factal(x_2d) + assert isinstance(result_2d, float) + + # Test with 1D array (should return 1) + x_1d = random.normal(key, shape=(100,)) + assert compute_d_factal(x_1d) == 1 + + # Test with k > number of samples (expect a warning) + x_small = random.normal(key, shape=(5, 10)) + logger = logging.getLogger("mellon") + logger.propagate = True + with caplog.at_level(logging.WARNING, logger="mellon"): + compute_d_factal(x_small, k=10) + logger.propagate = False + assert "is greater than the number of samples" in caplog.text + + # Test with specific random seed + x_seed = random.normal(key, shape=(100, 10)) + result_seed = compute_d_factal(x_seed, seed=432) + assert isinstance(result_seed, float) + + # Test with n < number of samples + x_n = random.normal(key, shape=(1000, 10)) + result_n = compute_d_factal(x_n, n=500) + assert isinstance(result_n, float) + + # Test with invalid input (negative k) + with pytest.raises(ValueError): + compute_d_factal(x_2d, k=-5) + + +def test_compute_Lp(): + + # Generate some mock data and landmarks + x = jnp.array([[1, 2], [3, 4], [5, 6]]) + landmarks = jnp.array([[1, 2], [3, 4]]) + mock_cov_func = mellon.cov.Matern52(1) + + # Test 'full' Gaussian Process type + Lp = compute_Lp(x, mock_cov_func, gp_type="full") + assert Lp.shape == (3, 3) + 
assert isinstance(Lp, jnp.ndarray) + + # Test 'sparse_cholesky' with landmarks + Lp_sparse = compute_Lp( + x, mock_cov_func, gp_type="sparse_cholesky", landmarks=landmarks + ) + assert Lp_sparse.shape == (2, 2) + assert isinstance(Lp_sparse, jnp.ndarray) + + # Test full Nyström should return None + assert compute_Lp(x, mock_cov_func, gp_type="full_nystroem") is None + + # Test sparse Nyström should return None + assert compute_Lp(x, mock_cov_func, gp_type="sparse_nystroem") is None + + # Test with invalid GaussianProcessType + with pytest.raises(ValueError): + compute_Lp(x, mock_cov_func, gp_type="unknown_type") + + # Test without specifying gp_type (it should be inferred) + Lp_inferred = compute_Lp(x, mock_cov_func) + assert Lp_inferred is not None + assert isinstance(Lp_inferred, jnp.ndarray) + + # Test with custom sigma and jitter + Lp_custom = compute_Lp(x, mock_cov_func, sigma=0.1, jitter=0.001) + assert Lp_custom.shape == (3, 3) + assert isinstance(Lp_custom, jnp.ndarray) + + # Test with no landmarks for 'sparse_cholesky' + Lp_no_landmarks = compute_Lp(x, mock_cov_func, gp_type="sparse_cholesky") + assert Lp_no_landmarks.shape == (3, 3) + assert isinstance(Lp_no_landmarks, jnp.ndarray) + + +def test_compute_L(): + x = jnp.array([[1, 2], [3, 4], [5, 6], [8, 8]]) + landmarks = jnp.array([[1, 2], [3, 4], [5, 6]]) + mock_cov_func = mellon.cov.ExpQuad(1.1) + + # Test FULL type with Lp=None + L = compute_L(x, mock_cov_func, gp_type="full") + assert L.shape == (4, 4) + assert isinstance(L, jnp.ndarray) + + # Test FULL type with Lp as an array + Lp = jnp.array([[0.5, 0.1], [0.1, 0.5]]) + with pytest.raises(ValueError): + compute_L(x, mock_cov_func, gp_type="full", Lp=Lp) + + # Test FULL_NYSTROEM type + L = compute_L(x, mock_cov_func, gp_type="full_nystroem", rank=2) + assert L.shape == (4, 2) + assert isinstance(L, jnp.ndarray) + + # Test SPARSE_CHOLESKY with landmarks and Lp=None + L = compute_L(x, mock_cov_func, gp_type="sparse_cholesky", landmarks=landmarks) + assert L.shape == (4, 3) + assert isinstance(L, jnp.ndarray) + + # Test SPARSE_CHOLESKY with landmarks and Lp as an array + Lp = jnp.array([[0.5, 0.1], [0.1, 0.5]]) + L = compute_L( + x, + mock_cov_func, + gp_type="sparse_cholesky", + landmarks=landmarks[:2, :], + Lp=Lp, + ) + assert L.shape == (4, 2) + assert isinstance(L, jnp.ndarray) + + with pytest.raises(ValueError): + wrong_shape_Lp = jnp.array([[0.5, 0.1, 0.2], [0.1, 0.5, 0.2]]) + compute_L( + x, + mock_cov_func, + gp_type="sparse_cholesky", + landmarks=landmarks, + Lp=wrong_shape_Lp, + ) + + # Test SPARSE_NYSTROEM + L = compute_L( + x, mock_cov_func, gp_type="sparse_nystroem", landmarks=landmarks, rank=2 + ) + assert L.shape == (4, 2) + assert isinstance(L, jnp.ndarray) + + # Test with unknown gp_type + with pytest.raises(ValueError): + compute_L(x, mock_cov_func, gp_type="unknown") + + # Test with custom rank, sigma, and jitter + L = compute_L(x, mock_cov_func, rank=2, sigma=0.1, jitter=0.001) + assert L.shape == (4, 2) + assert isinstance(L, jnp.ndarray) + + +def test_gaussian_process_type(): + assert ( + GaussianProcessType.from_string("full") == GaussianProcessType.FULL + ), "Error converting 'full' to FULL enum." + assert ( + GaussianProcessType.from_string("full_nystroem") + == GaussianProcessType.FULL_NYSTROEM + ), "Error converting 'full_nystroem' to FULL_NYSTROEM enum." + assert ( + GaussianProcessType.from_string("sparse_cholesky") + == GaussianProcessType.SPARSE_CHOLESKY + ), "Error converting 'sparse_cholesky' to SPARSE_CHOLESKY enum." 
+ assert ( + GaussianProcessType.from_string("sparse_nystroem") + == GaussianProcessType.SPARSE_NYSTROEM + ), "Error converting 'sparse_nystroem' to SPARSE_NYSTROEM enum." + + with pytest.raises(ValueError): + GaussianProcessType.from_string( + "unknown_type" + ), "Error was expected with unknown Gaussian Process type." + + partial_input = GaussianProcessType.from_string("sparse") + assert isinstance( + partial_input, GaussianProcessType + ), "Error converting partial input to an enum instance." + assert partial_input in [ + GaussianProcessType.SPARSE_CHOLESKY, + GaussianProcessType.SPARSE_NYSTROEM, + ], "Error matching partial input to one of the SPARSE enums." + + none_input = GaussianProcessType.from_string(None, optional=True) + assert none_input is None, "Error handling None input with optional flag." + + with pytest.raises(ValueError): + GaussianProcessType.from_string( + None + ), "Error was expected with None input without optional flag." + + +def test_compute_nn_distances(): + # Test with different shapes of input + x = jnp.array([[1, 2], [2, 3], [3, 4]]) + expected_output = jnp.array([jnp.sqrt(2), jnp.sqrt(2), jnp.sqrt(2)]) + assert jnp.allclose(compute_nn_distances(x), expected_output) + + # Test with identical instances and save=True + x = jnp.array([[1, 2], [1, 2], [1, 2]]) + with pytest.raises(ValueError): + compute_nn_distances(x) + + # Test with some identical instances and save=True + x = jnp.array([[1, 2], [1, 2], [3, 4]]) + expected_output = jnp.array([jnp.sqrt(8), jnp.sqrt(8), jnp.sqrt(8)]) + assert jnp.allclose(compute_nn_distances(x), expected_output) + + # Test with non-positive distances and save=False + x = jnp.array([[1, 2], [1, 2], [1, 2]]) + expected_output = jnp.array([0, 0, 0]) + assert jnp.allclose(compute_nn_distances(x, save=False), expected_output) + + # Test with varying distances + x = jnp.array([[1, 1], [2, 2], [4, 4], [5, 5]]) + expected_output = jnp.array([jnp.sqrt(2), jnp.sqrt(2), jnp.sqrt(2), jnp.sqrt(2)]) + assert jnp.allclose(compute_nn_distances(x), expected_output) + + # Test with one instance + x = jnp.array([[1, 2]]) + with pytest.raises(ValueError): + compute_nn_distances(x) + + # Test with empty array + x = jnp.array([]) + with pytest.raises(ValueError): + compute_nn_distances(x) + + +def test_compute_gp_type(): + # Test full model with integer rank, float rank, None rank, and zero rank + assert compute_gp_type(0, 100, 100) == GaussianProcessType.FULL + assert compute_gp_type(100, 1.0, 100) == GaussianProcessType.FULL + assert compute_gp_type(100, None, 100) == GaussianProcessType.FULL + assert compute_gp_type(100, 0, 100) == GaussianProcessType.FULL + + # Test full model with Nyström rank reduction + assert compute_gp_type(100, 50, 100) == GaussianProcessType.FULL_NYSTROEM + assert compute_gp_type(100, 0.5, 100) == GaussianProcessType.FULL_NYSTROEM + + # Test sparse model with integer rank, float rank, None rank, and zero rank + assert compute_gp_type(50, 50, 100) == GaussianProcessType.SPARSE_CHOLESKY + assert compute_gp_type(50, 1.0, 100) == GaussianProcessType.SPARSE_CHOLESKY + assert compute_gp_type(50, None, 100) == GaussianProcessType.SPARSE_CHOLESKY + assert compute_gp_type(50, 0, 100) == GaussianProcessType.SPARSE_CHOLESKY + + # Test sparse model with Nyström rank reduction + assert compute_gp_type(50, 25, 100) == GaussianProcessType.SPARSE_NYSTROEM + assert compute_gp_type(50, 0.5, 100) == GaussianProcessType.SPARSE_NYSTROEM def test_compute_mu(): @@ -27,26 +325,53 @@ def test_cov(): assert cov() == test_ls, "cov should produce the 
 def test_compute_mu():
@@ -27,26 +325,53 @@ def test_cov():
     assert cov() == test_ls, "cov should produce the expected value."
 
-def test_compute_L():
-    def cov(x, y):
-        return jnp.ones((x.shape[0], y.shape[0]))
+def test_compute_rank():
+    assert compute_rank(GaussianProcessType.FULL_NYSTROEM) == 0.99
+    assert compute_rank(GaussianProcessType.SPARSE_CHOLESKY) == 1.0
+    assert compute_rank(None) == 1.0
 
-    n = 2
-    d = 2
-    X = jnp.ones((n, d))
-    L = mellon.parameters.compute_L(X, cov)
-    assert L.shape[0] == n, "L should have as many rows as there are samples."
-    L = mellon.parameters.compute_L(X, cov, rank=1.0)
-    assert L.shape == (n, n), "L should have full rank."
-    L = mellon.parameters.compute_L(X, cov, rank=1)
-    assert L.shape == (n, 1), "L should be reduced to rank == 1."
-    mellon.parameters.compute_L(X, cov, rank=0.5)
-    mellon.parameters.compute_L(X, cov, landmarks=X)
-    L = mellon.parameters.compute_L(X, cov, landmarks=X, rank=1.0)
-    assert L.shape == (n, n), "L should have full rank."
-    L = mellon.parameters.compute_L(X, cov, landmarks=X, rank=1)
-    assert L.shape == (n, 1), "L should be reduced to rank == 1."
-    mellon.parameters.compute_L(X, cov, landmarks=X, rank=0.5)
+
+def test_compute_n_landmarks():
+    DEFAULT_N_LANDMARKS = 5000
+
+    # Test when landmarks are not None
+    landmarks = jnp.ones((50, 2))
+    assert compute_n_landmarks(None, 100, landmarks) == 50
+
+    # Test when gp_type is None
+    assert compute_n_landmarks(None, 100, None) == min(100, DEFAULT_N_LANDMARKS)
+
+    # Test with FULL or FULL_NYSTROEM gp_type
+    assert compute_n_landmarks(GaussianProcessType.FULL, 100, None) == 100
+    assert compute_n_landmarks(GaussianProcessType.FULL_NYSTROEM, 100, None) == 100
+
+    # Test with SPARSE_CHOLESKY or SPARSE_NYSTROEM gp_type
+    assert (
+        compute_n_landmarks(GaussianProcessType.SPARSE_CHOLESKY, 100, None)
+        == DEFAULT_N_LANDMARKS
+    )
+    assert (
+        compute_n_landmarks(GaussianProcessType.SPARSE_NYSTROEM, 100, None)
+        == DEFAULT_N_LANDMARKS
+    )
+
+    # Test with SPARSE_CHOLESKY or SPARSE_NYSTROEM gp_type and n_samples <= DEFAULT_N_LANDMARKS
+    assert (
+        compute_n_landmarks(GaussianProcessType.SPARSE_CHOLESKY, 80, None)
+        == DEFAULT_N_LANDMARKS
+    )
+    assert (
+        compute_n_landmarks(GaussianProcessType.SPARSE_NYSTROEM, 80, None)
+        == DEFAULT_N_LANDMARKS
+    )
+
+    # Test with unknown gp_type
+    class UnknownType(Enum):
+        UNKNOWN = "unknown"
+
+    assert compute_n_landmarks(UnknownType.UNKNOWN, 100, None) == min(
+        100, DEFAULT_N_LANDMARKS
+    )
 def test_compute_initial_value():
diff --git a/tests/time_sensitive_density_estimator.py b/tests/time_sensitive_density_estimator.py
index 3671d78..1b78471 100644
--- a/tests/time_sensitive_density_estimator.py
+++ b/tests/time_sensitive_density_estimator.py
@@ -79,21 +79,20 @@ def test_time_sensitive_density_estimator_properties(common_setup_time_sensitive
 @pytest.mark.parametrize(
-    "rank, method, n_landmarks, err_limit",
+    "rank, n_landmarks, err_limit",
     [
-        (1.0, "percent", 10, 2e-1),
-        (0.99, "percent", 80, 5e-1),
+        (1.0, 10, 2e-1),
+        (0.99, 80, 5e-1),
     ],
 )
 def test_time_sensitive_density_estimator_approximations(
-    common_setup_time_sensitive, rank, method, n_landmarks, err_limit
+    common_setup_time_sensitive, rank, n_landmarks, err_limit
 ):
     X, times, _, _, relative_err, _, _, _ = common_setup_time_sensitive
     n = X.shape[0]
     est = mellon.TimeSensitiveDensityEstimator(
         rank=rank,
-        method=method,
         n_landmarks=n_landmarks,
         _save_intermediate_ls_times=True,
         normalize_per_time_point=True,
@@ -122,22 +121,20 @@ def test_time_sensitive_density_estimator_approximations(
 @pytest.mark.parametrize(
-    "rank, method, n_landmarks, compress",
+    "rank, n_landmarks, compress",
     [
-        (1.0, "percent", 10, None),
-        (0.8, "percent", 10, None),
-        (0.99, "percent", 80, "gzip"),
-        (0.99, "percent", 80, "bz2"),
+        (1.0, 10, None),
+        (0.8, 10, None),
+        (0.99, 80, "gzip"),
+        (0.99, 80, "bz2"),
     ],
 )
 def test_time_sensitive_density_estimator_serialization(
-    common_setup_time_sensitive, rank, method, n_landmarks, compress
+    common_setup_time_sensitive, rank, n_landmarks, compress
 ):
     X, times, test_file, logger, _, _, _, _ = common_setup_time_sensitive
-    est = mellon.TimeSensitiveDensityEstimator(
-        rank=rank, method=method, n_landmarks=n_landmarks
-    )
+    est = mellon.TimeSensitiveDensityEstimator(rank=rank, n_landmarks=n_landmarks)
     est.fit(X, times)
     dens_appr = est.predict(X, times)
@@ -154,3 +151,126 @@ def test_time_sensitive_density_estimator_serialization(
     logger.info(
         "Assertion passed: the deserialized predictor produced the expected results."
     )
+
+
+@pytest.mark.parametrize(
+    "rank, n_landmarks, compress",
+    [
+        (1.0, 10, None),
+        (0.8, 10, None),
+        (0.99, 80, "gzip"),
+        (0.99, 80, "bz2"),
+    ],
+)
+def test_density_estimator_serialization_with_uncertainty(
+    common_setup_time_sensitive, rank, n_landmarks, compress
+):
+    X, times, test_file, logger, _, _, _, _ = common_setup_time_sensitive
+    n = X.shape[0]
+
+    est = mellon.TimeSensitiveDensityEstimator(
+        rank=rank,
+        n_landmarks=n_landmarks,
+        optimizer="advi",
+        predictor_with_uncertainty=True,
+    )
+    est.fit(X, times)
+    dens_appr = est.predict(X, times)
+    covariance = est.predict.covariance(X, times)
+    assert covariance.shape == (
+        n,
+    ), "The diagonal of the covariance matrix should be reported."
+    mean_covariance = est.predict.mean_covariance(X, times)
+    assert mean_covariance.shape == (
+        n,
+    ), "The diagonal of the mean covariance should be reported."
+    uncertainty_pred = est.predict.uncertainty(X, times)
+    assert uncertainty_pred.shape == (n,), "One value per sample should be reported."
+
+    # Test serialization
+    est.predict.to_json(test_file, compress=compress)
+    logger.info(
+        f"Serialized the predictor with uncertainty and saved it to {test_file}."
+    )
+    predictor = mellon.Predictor.from_json(test_file, compress=compress)
+    logger.info("Deserialized the predictor with uncertainty from the JSON file.")
+    reprod = predictor(X, times)
+    logger.info("Made a prediction with the deserialized predictor.")
+    is_close = jnp.all(jnp.isclose(dens_appr, reprod))
+    assert_msg = "Serialized + deserialized predictor should produce the same results."
+    assert is_close, assert_msg
+    reprod_uncertainty = predictor.uncertainty(X, times)
+    logger.info("Made an uncertainty prediction with the deserialized predictor.")
+    is_close = jnp.all(jnp.isclose(uncertainty_pred, reprod_uncertainty))
+    assert_msg = "Serialized + deserialized predictor should produce the same uncertainty results."
+    assert is_close, assert_msg
+    logger.info(
+        "Assertion passed: the deserialized predictor produced the expected results."
+    )
+
+
+def test_density_estimator_errors(common_setup_time_sensitive):
+    X, times, _, _, _, _, _, _ = common_setup_time_sensitive
+    Xt = jnp.concatenate([X, times[:, None]], axis=1)
+    lX = jnp.concatenate(
+        [
+            X,
+        ]
+        * 26
+        + [
+            times[:, None],
+        ],
+        axis=1,
+    )
+    est = mellon.TimeSensitiveDensityEstimator()
+
+    with pytest.raises(ValueError):
+        est.fit_predict()
+    with pytest.raises(ValueError):
+        est.fit(None)
+    est.set_x(Xt)
+    with pytest.raises(ValueError):
+        est.prepare_inference(lX)
+    loss_func, initial_value = est.prepare_inference(None)
+    est.run_inference(loss_func, initial_value, "advi")
+    est.process_inference(est.pre_transformation)
+    with pytest.raises(ValueError):
+        est.predict(X[:, :-1], times)
+    with pytest.raises(ValueError):
+        est.fit_predict(lX)
+    est.fit_predict()
+    est.predict.n_obs = None
+    with pytest.raises(ValueError):
+        est.predict(X, time=times, normalize=True)
+
+
+@pytest.mark.parametrize(
+    "normalization, different",
+    [
+        (False, False),
+        (True, False),
+        ([4, 4, 1000, 4], True),
+        (jnp.array([4, 4, 1000, 4]), True),
+        ({1: 4, 0: 4, 2: 1000, 3: 4}, True),
+    ],
+)
+def test_time_sensitive_density_estimator_normalizations(
+    common_setup_time_sensitive, normalization, different
+):
+    X, times, _, _, relative_err, _, _, _ = common_setup_time_sensitive
+    err_limit = 1e-4
+    min_diff = 1e-1
+
+    est = mellon.TimeSensitiveDensityEstimator(
+        normalize_per_time_point=normalization,
+    )
+    est.fit(X, times)
+    dens_appr = est.predict(X, times)
+    if different:
+        assert (
+            relative_err(dens_appr) > min_diff
+        ), "This normalization should be different from the default."
+    else:
+        assert (
+            relative_err(dens_appr) < err_limit
+        ), "This normalization should be close to the default."
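A usage sketch of the per-time-point normalization covered by the last test; the synthetic data and the count targets are illustrative:

```python
import numpy as np
import mellon

# Illustrative synthetic data with four time points labelled 0-3.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
times = np.repeat(np.arange(4), 25)

# Hypothetical relative cell-count targets per time point; time point 2 is
# treated as if far more cells had been sampled there.
est = mellon.TimeSensitiveDensityEstimator(
    normalize_per_time_point={0: 4, 1: 4, 2: 1000, 3: 4},
)
est.fit(X, times)
density = est.predict(X, times)
```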
diff --git a/tests/util.py b/tests/util.py
index ab3ac64..2316375 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -49,3 +49,9 @@ def test_Log():
     assert hasattr(logger, "info")
     assert hasattr(logger, "warn")
     assert hasattr(logger, "error")
+    mellon.util.Log.off()
+    mellon.util.Log.on()
+
+
+def test_set_jax_config():
+    mellon.util.set_jax_config()
diff --git a/tests/validation.py b/tests/validation.py
index dfe8f46..eb3c047 100644
--- a/tests/validation.py
+++ b/tests/validation.py
@@ -10,11 +10,57 @@
     _validate_string,
     _validate_float_or_iterable_numerical,
     _validate_positive_int,
+    _validate_1d,
+)
+from mellon.parameter_validation import (
+    _validate_params,
     _validate_cov_func_curry,
     _validate_cov_func,
-    _validate_1d,
 )
 from mellon.cov import Covariance
+from mellon.util import GaussianProcessType
+
+
+@pytest.mark.parametrize(
+    "rank, gp_type, n_samples, n_landmarks, landmarks, exception_expected",
+    [
+        # Test valid cases
+        (1.0, GaussianProcessType.FULL, 100, 0, None, None),
+        (0.5, GaussianProcessType.FULL_NYSTROEM, 100, 100, None, None),
+        (1.0, GaussianProcessType.SPARSE_CHOLESKY, 100, 50, None, None),
+        (0.5, GaussianProcessType.SPARSE_NYSTROEM, 100, 50, jnp.zeros((50, 5)), None),
+        # Test error for invalid rank
+        (None, GaussianProcessType.FULL, 100, 0, None, ValueError),
+        ("some_type", GaussianProcessType.FULL, 100, 50, None, ValueError),
+        (0.9, GaussianProcessType.FULL, 100, 0, None, ValueError),
+        # Test error for invalid gp_type (not a GaussianProcessType instance)
+        (1.0, "some_type", 100, 0, None, ValueError),
+        # Test error cases for landmarks
+        (
+            0.5,
+            GaussianProcessType.SPARSE_NYSTROEM,
+            100,
+            51,
+            jnp.zeros((50, 5)),
+            ValueError,
+        ),
+        (None, GaussianProcessType.FULL, 100, 50, jnp.zeros((60, 5)), ValueError),
+        (None, GaussianProcessType.FULL_NYSTROEM, 100, 50, None, ValueError),
+        (0.5, GaussianProcessType.SPARSE_CHOLESKY, 100, 0, None, ValueError),
+        (1.0, GaussianProcessType.FULL, 100, 10, None, ValueError),
+        (0, GaussianProcessType.SPARSE_NYSTROEM, 100, 100, None, ValueError),
+        (2.0, GaussianProcessType.FULL_NYSTROEM, 100, 0, None, ValueError),
+        (100, GaussianProcessType.SPARSE_NYSTROEM, 100, 50, None, ValueError),
+    ],
+)
+def test_validate_params(
+    rank, gp_type, n_samples, n_landmarks, landmarks, exception_expected
+):
+    if exception_expected:
+        with pytest.raises(exception_expected):
+            _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks)
+    else:
+        _validate_params(rank, gp_type, n_samples, n_landmarks, landmarks)
 def test_validate_float_or_int():
@@ -128,21 +174,49 @@ def test_validate_string():
 def test_validate_float_or_iterable_numerical():
-    # Test with float input
-    assert _validate_float_or_iterable_numerical(10.5, "param") == 10.5
+    # Test with positive numbers
+    assert _validate_float_or_iterable_numerical(5, "value") == 5.0
+    assert jnp.allclose(
+        _validate_float_or_iterable_numerical([5, 6], "value"), jnp.asarray([5.0, 6.0])
+    )
+
+    # Test with negative numbers, without positive constraint
+    assert _validate_float_or_iterable_numerical(-5, "value") == -5.0
+    assert jnp.allclose(
+        _validate_float_or_iterable_numerical([-5, -6], "value"),
+        jnp.asarray([-5.0, -6.0]),
+    )
+
+    # Test with zero
+    assert _validate_float_or_iterable_numerical(0, "value") == 0.0
+
+    # Test with positive=True
+    assert _validate_float_or_iterable_numerical(5, "value", positive=True) == 5.0
+
+    # Test with negative numbers and positive=True
+    with pytest.raises(ValueError):
+        _validate_float_or_iterable_numerical(-5, "value", positive=True)
-    # Test with iterable input
-    array = jnp.array([1, 2, 3])
-    validated_array = _validate_float_or_iterable_numerical(array, "param")
-    assert jnp.array_equal(validated_array, array)
+    with pytest.raises(ValueError):
+        _validate_float_or_iterable_numerical([-5, 6], "value", positive=True)
+
+    # Test with None and optional=True
+    assert _validate_float_or_iterable_numerical(None, "value", optional=True) is None
+
+    # Test with None and optional=False
+    with pytest.raises(TypeError):
+        _validate_float_or_iterable_numerical(None, "value", optional=False)
+
+    # Test with non-numeric types
+    with pytest.raises(TypeError):
+        _validate_float_or_iterable_numerical("string", "value")
-    # Test with non-numeric iterable input
     with pytest.raises(ValueError):
-        _validate_float_or_iterable_numerical(["invalid", "input"], "param")
+        _validate_float_or_iterable_numerical(["string"], "value")
-    # Test with non-numeric, non-iterable input
+    # Test with mixed numeric and non-numeric iterable
     with pytest.raises(ValueError):
-        _validate_float_or_iterable_numerical("invalid", "param")
+        _validate_float_or_iterable_numerical([5, "string"], "value")
 def test_validate_time_x():