Skip to content

Commit

Permalink
Merge pull request #7 from settylab/uncertainty
Browse files Browse the repository at this point in the history
this merge commit is the release candidate for Mellon v1.4.0
  • Loading branch information
katosh authored Oct 3, 2023
2 parents 3998469 + 924aa5a commit 64af1ce
Show file tree
Hide file tree
Showing 38 changed files with 4,274 additions and 983 deletions.
87 changes: 87 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# v1.4.0
## New Features
### `with_uncertainty` Parameter
Integrates a boolean parameter `with_uncertainty` across all estimators: [DensityEstimator](https://mellon.readthedocs.io/en/uncertainty/model.html#mellon.model.DensityEstimator), TimeSensitiveDensityEstimator, FunctionEstimator, and DimensionalityEstimator. It modifies the fitted predictor, accessible via the `.predict` property, to include the following methods:
- `.covariance(X)`: Calculates the (co-)variance of the posterior Gaussian Process (GP).
- Almost 0 near landmarks; grows for out-of-sample locations.
- Increases with sparsity.
- Defaults to `diag=True`, computing only the covariance matrix diagonal.
- `.mean_covariance(X)`: Computes the (co-)variance through the uncertainty of the mean function's GP posterior.
- Derived from Bayesian inference for latent density function representation.
- Increases in low data or low-density areas.
- Only available with posterior uncertainty quantification, e.g., `optimizer='advi'` except for the `FunctionEstimator` where input uncertainty is specified through the `sigma` parameter.
- Defaults to `diag=True`, computing only the covariance matrix diagonal.
- `.uncertainty(X)`: Combines `.covariance(X)` and `.mean_covariance(X)`.
- Defaults to `diag=True`, computing only the covariance matrix diagonal.
- Square root provides standard deviation.

### `gp_type` Parameter
Introduces the `gp_type` parameter to all relevant [estimators](https://mellon.readthedocs.io/en/uncertainty/model.html) to explicitly specify the Gaussian Process (GP) sparsification strategy, replacing the previously used `method` argument (with options auto, fixed, and percent) that implicitly controlled sparsification. The available options for `gp_type` include:
- 'full': Non-sparse GP.
- 'full_nystroem': Sparse GP with Nyström rank reduction, lowering computational complexity.
- 'sparse_cholesky': SParse GP using landmarks/inducing points.
- 'sparse_nystroem': Improved Nyström rank reduction on sparse GP with landmarks, balancing accuracy and efficiency.

This new parameter adds additional validation steps, ensuring that no contradictory parameters are specified. If inconsistencies are detected, a helpful reply guides the user on how to fix the issue. The value can be either a string matching one of the options above or an instance of the `mellon.parameters.GaussianProcessType` Enum. Partial matches log a warning, using the closest match. Defaults to 'sparse_cholesky'.

*Note: Nyström strategies are not applicable to the **FunctionEstimator**.*

### `y_is_mean` Parameter
Adds a boolean parameter `y_is_mean` to [FunctionEstimator](https://mellon.readthedocs.io/en/uncertainty/model.html#mellon.model.FunctionEstimator), affecting how `y` values are interpreted:
- **Old Behavior**: `sigma` impacted conditional mean functions and predictions.
- **Intermediate Behavior**: `sigma` only influenced prediction uncertainty.
- **New Parameter**: If `y_is_mean=True`, `y` values are treated as a fixed mean; `sigma` reflects only uncertainty. If `y_is_mean=False`, `y` is considered a noisy measurement, potentially smoothing values at locations `x`.

This change benefits DensityEstimator, TimeSensitiveDensityEstimator, and DimensionalityEstimator where function values are predicted for out-of-sample locations after mean GP computation.

### `check_rank` Parameter
Introduces the `check_rank ` parameter to all relevant [estimators](https://mellon.readthedocs.io/en/uncertainty/model.html). This boolean parameter explicitly controls whether the rank check is performed, specifically in the `gp_type="sparse_cholesky"` case. The rank check assesses the chosen landmarks for adequate complexity by examining the approximate rank of the covariance matrix, issuing a warning if insufficient. Allowed values are:
- `True`: Always perform the check.
- `False`: Never perform the check.
- `None` (Default): Perform the check only if `n_landmarks` is greater than or equal to `n_samples` divided by 10.

The default setting aims to bypass unnecessary computation when the number of landmarks is so abundant that insufficient complexity becomes improbable.

### `normalize` Parameter

The `normalize` parameter is applicable to both the [`.mean`](https://mellon.readthedocs.io/en/uncertainty/serialization.html#mellon.Predictor.mean) method and `.__call__` method within the [mellon.Predictor](https://mellon.readthedocs.io/en/uncertainty/serialization.html#predictor-class) class. When set to `True`, these methods will subtract `log(number of observations)` from the value returned. This feature is particularly useful with the [DensityEstimator](https://mellon.readthedocs.io/en/uncertainty/model.html#mellon.model.DensityEstimator), where normalization adjusts for the number of cells in the training sample, allowing for accurate density comparisons between datasets. This correction takes into account the effect of dataset size, ensuring that differences in total cell numbers are not unduly influential. By default, the parameter is set to `False`, meaning that density differences due to variations in total cell number will remain uncorrected.

### `normalize_per_time_point` Parameter

This parameter fine-tunes the `TimeSensitiveDensityEstimator` to handle variations in sampling bias across different time points, ensuring both continuity and differentiability in the resulting density estimation. Notably, it also allows to reflect the growth of a population even if the same number of cells were sampled from each time point.

The normalization is realized by manipulating the nearest neighbor distances
`nn_distances` to reflect the deviation from an expected cell count.

- **Type**: Optional, accepts `bool`, `list`, `array-like`, or `dict`.

#### Options:

- **`True`:** Normalizes to emulate an even distribution of total cell count across all time points.
- **`False`:** Retains raw cell counts at each time point for density estimation.
- **List/Array-like**: Specifies an ordered sequence of total cell count targets for each time point, starting with the earliest.
- **Dict**: Associates each unique time point with a specific total cell count target.

#### Notes:

- **Relative Metrics**: While this parameter adjusts for sample bias, it only requires relative cell counts for comparisons within the dataset; exact counts are not mandatory.
- **`nn_distance` Precedence**: If `nn_distance` is supplied, this parameter will be bypassed, and the provided distances will be used directly.
- The default value is `False`


## Enhancements
- Optimization by saving the intermediate result `Lp` in the estimators for reuse, enhancing the speed of the predictive function computation in non-Nyström strategies.
- The `DimensionalityEstimator.predict` now returns a subclass of the `mellon.Predictor` class instead of a closure. Giving access to serialization and uncertainty computations.
- Expanded testing.
- propagate logging messages and explicit logger name "mellon" everywhere
- extended parameter validation for the estimators now also applies to the `compute_L` function
- better string representation of estimators and predictors
- bugfix some edge cases
- Revise some documentation (s. b70bb04a4e921ceab63b60026b8033e384a8916a) and include [Predictor](https://mellon.readthedocs.io/en/uncertainty/predictor.html) page on sphinx doc

## Changes

- The mellon.Predictor class now has a method `.mean` that is an alias to `.__call__`.
- All mellon.Predictor sub classes `...ConditionalMean...` were renamed to `...Conditional...` since they now also compute `.covariance` and `.mean_covariance`.
- All generating methods for mellon.Predictor were renamed from `...conditional_mean...` to `conditional`.
- A new log message now informs that the normalization is not effective `d_method != "fractal"`. Additionally, using `normalize=True` in the density predictor triggers a warning that one has to use the non default `d_method = "fractal"` in the `DensityEstimator`.
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ Mellon
:target: https://zenodo.org/badge/latestdoi/558998366
.. image:: https://codecov.io/github/settylab/Mellon/branch/main/graph/badge.svg?token=TKIKXK4MPG
:target: https://app.codecov.io/github/settylab/Mellon
.. image:: https://www.codefactor.io/repository/github/settylab/mellon/badge/main
:target: https://www.codefactor.io/repository/github/settylab/mellon/overview/main
:alt: CodeFactor
.. image:: https://badge.fury.io/py/mellon.svg
:target: https://badge.fury.io/py/mellon
.. image:: https://anaconda.org/conda-forge/mellon/badges/version.svg
Expand Down
3 changes: 2 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def get_version(rel_path):
"sphinx.ext.autodoc",
"nbsphinx",
"sphinx.ext.napoleon",
"sphinx_github_style",
]
if os.environ.get('READTHEDOCS') == 'True':
extensions.append("sphinx_github_style")

source_suffix = [".rst", ".md"]

Expand Down
3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
:caption: Modules:

Model <model>
Predictor <predictor>
Serialization <serialization>
Utilities <util>
Covariance Functions <cov>
Expand All @@ -21,7 +22,7 @@
notebooks/trajectory-trends_tutorial.ipynb
notebooks/gene_change_analysis_tutorial.ipynb
notebooks/time-series_tutorial.ipynb

.. include:: ../../README.rst

.. toctree::
Expand Down
98 changes: 98 additions & 0 deletions docs/source/predictor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
Predictors
==========

Predictors in the Mellon framework can be invoked directly via their `__call__`
method to produce function estimates at new locations. These predictors can
also double as Gaussian Processes, offering uncertainty estimattion options.
It also comes with serialization capabilities detailed in :ref:`serialization <serialization>`.

Basic Usage
-----------

To generate estimates for new, out-of-sample locations, instantiate a
predictor and call it like a function:

.. code-block:: python
:caption: Example of accessing the :class:`mellon.Predictor` from the
:class:`mellon.model.DensityEstimator` in Mellon Framework
:name: example-usage-density-predictor
model = mellon.model.DensityEstimator(...) # Initialize the model with appropriate arguments
model.fit(X) # Fit the model to the data
predictor = model.predict # Obtain the predictor object
predicted_values = predictor(Xnew) # Generate predictions for new locations
Uncertainy
------------

If the predictor was generated with
uncertainty estimates (typically by passing `predictor_with_uncertainty=True`
and `optimizer="advi"` to the model class, e.g., :class:`mellon.model.DensityEstimator`)
then it provides methods for computing variance at these locations, and co-variance to any other
location.

- Variance Methods:
- :meth:`mellon.Predictor.covariance`
- :meth:`mellon.Predictor.mean_covariance`
- :meth:`mellon.Predictor.uncertainty`

Sub-Classes
-----------

The `Predictor` module in the Mellon framework features a variety of
specialized subclasses of :class:`mellon.Predictor`. The specific subclass
instantiated by the model is contingent upon two key parameters:

- `gp_type`: This argument determines the type of Gaussian Process used internally.
- The nature of the predicted output: This can be real-valued, strictly positive, or time-sensitive.

The `gp_type` argument mainly affects the internal mathematical operations,
whereas the nature of the predicted value dictates the subclass's functional
capabilities:

- **Real-valued Predictions**: Such as log-density estimates, :class:`mellon.Predictor`.
- **Positive-valued Predictions**: Such as dimensionality estimates, :class:`mellon.base_predictor.ExpPredictor`.
- **Time-sensitive Predictions**: Such as time-sensitive density estimates :class:`mellon.base_predictor.PredictorTime`.



Vanilla Predictor
-----------------

Utilized in the following methods:

- :attr:`mellon.model.DensityEstimator.predict`
- :attr:`mellon.model.DimensionalityEstimator.predict_density`
- :attr:`mellon.model.FunctionEstimator.predict`

.. autoclass:: mellon.Predictor
:members:
:undoc-members:
:show-inheritance:
:exclude-members: n_obs, n_input_features

Exponential Predictor
---------------------

- Used in :attr:`mellon.model.DimensionalityEstimator.predict`
- Predicted values are strictly positive. Variance is expressed in log scale.

.. autoclass:: mellon.base_predictor.ExpPredictor
:members:
:undoc-members:
:show-inheritance:
:exclude-members: n_obs, n_input_features

Time-sensitive Predictor
------------------------

- Utilized in :attr:`mellon.model.TimeSensitiveDensityEstimator.predict`
- Special arguments `time` and `multi_time` permit time-specific predictions.

.. autoclass:: mellon.base_predictor.PredictorTime
:members:
:undoc-members:
:show-inheritance:
:exclude-members: n_obs, n_input_features

44 changes: 28 additions & 16 deletions docs/source/serialization.rst
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
.. _serialization:

Serialization
=============

The Mellon module provides a comprehensive suite of tools for serializing and deserializing estimators. This functionality is particularly useful in computational biology where models might be trained on one dataset and then used for making predictions on new datasets. The ability to serialize an estimator allows you to save the state of a model after it has been trained, and then load it later for predictions without needing to retrain the model.

For instance, you might have a large dataset on which you train an estimator. Training might take a long time due to the size and complexity of the data. With serialization, you can save the trained estimator and then easily share it with collaborators, apply it to new data, or use it in a follow-up study.

When an estimator is serialized, it includes a variety of metadata, including the serialization date, Python version, the class name of the estimator, and the Mellon version.
This metadata serves as a detailed record of the state of your environment at the time of serialization, which can be crucial for reproducibility in scientific research and for understanding the conditions under which the estimator was originally trained.
The Mellon module facilitates the serialization and deserialization of
predictors, streamlining model transfer and reuse. This is especially relevant
in computational biology for applying pre-trained models to new datasets.
Serialization captures essential metadata like date, Python version, estimator
class, and Mellon version, ensuring research reproducibility and context for
the original training.

All estimators in Mellon that inherit from the :class:`BaseEstimator` class, including :class:`mellon.model.DensityEstimator`, :class:`mellon.model.FunctionEstimator`, and :class:`mellon.model.DimensionalityEstimator`, have serialization capabilities.
After fitting data, all Mellon models generate a predictor via their `.predict`
property, including model classes like :class:`mellon.model.DensityEstimator`,
:class:`mellon.model.TimeSensitiveDensityEstimator`,
:class:`mellon.model.FunctionEstimator`, and
:class:`mellon.model.DimensionalityEstimator`.


Predictor Class
---------------

The `Predictor` class, accessible through the `predict` property of an estimator, handles the serialization and deserialization process.

.. autoclass:: mellon.Predictor
:members:
:undoc-members:
:show-inheritance:
All predictors inherit their serialization methods from :class:`mellon.Predictor`.

Serialization to AnnData
------------------------

Estimators can be serialized to an `AnnData`_ object. The `log_density` computation for the AnnData object shown below is a simplified example. For a more comprehensive guide on how to properly preprocess data to compute the log density, refer to our
Predictors can be serialized to an `AnnData`_ object. The `log_density`
computation for the AnnData object shown below is a simplified example. For a
more comprehensive guide on how to properly preprocess data to compute the log
density, refer to our
`basic tutorial notebook <https://github.com/settylab/Mellon/blob/main/notebooks/basic_tutorial.ipynb>`_.


Expand All @@ -48,6 +52,9 @@ Estimators can be serialized to an `AnnData`_ object. The `log_density` computat
Deserialization from AnnData
----------------------------

The function :meth:`mellon.Predictor.from_dict` can deserialize the
:class:`mellon.Predictor` and any sub class.

.. code-block:: python
# Load the AnnData object
Expand All @@ -62,7 +69,11 @@ Deserialization from AnnData
Serialization to File
---------------------

Mellon supports serialization to a human-readable JSON file and compressed file formats such as .gz (gzip) and .bz2 (bzip2).
Mellon supports serialization to a human-readable JSON file and compressed file
formats such as .gz (gzip) and .bz2 (bzip2).

The function :meth:`mellon.Predictor.from_json` can deserialize the
:class:`mellon.Predictor` and any sub class.

.. code-block:: python
Expand All @@ -78,7 +89,8 @@ Mellon supports serialization to a human-readable JSON file and compressed file
Deserialization from File
-------------------------

Mellon supports deserialization from JSON and compressed file formats. The compression method can be inferred from the file extension.
Mellon supports deserialization from JSON and compressed file formats. The
compression method can be inferred from the file extension.

.. code-block:: python
Expand Down
4 changes: 3 additions & 1 deletion mellon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from . import _conditional as conditional
from . import _derivatives as derivatives

__version__ = "1.3.1"
__version__ = "1.4.0"

__all__ = [
"DensityEstimator",
Expand All @@ -37,6 +37,8 @@
"derivatives",
"__version__",
]
# Set up logger
Log()

# Set default configuration at import time
jaxconfig.update("jax_enable_x64", True)
Expand Down
30 changes: 18 additions & 12 deletions mellon/_conditional.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from .conditional import (
FullConditionalMean,
FullConditionalMeanTime,
LandmarksConditionalMean,
LandmarksConditionalMeanTime,
LandmarksConditionalMeanCholesky,
LandmarksConditionalMeanCholeskyTime,
FullConditional,
ExpFullConditional,
FullConditionalTime,
LandmarksConditional,
ExpLandmarksConditional,
LandmarksConditionalTime,
LandmarksConditionalCholesky,
ExpLandmarksConditionalCholesky,
LandmarksConditionalCholeskyTime,
)

__all__ = [
"FullConditionalMean",
"FullConditionalMeanTime",
"LandmarksConditionalMean",
"LandmarksConditionalMeanTime",
"LandmarksConditionalMeanCholesky",
"LandmarksConditionalMeanCholeskyTime",
"FullConditional",
"ExpFullConditional",
"FullConditionalTime",
"LandmarksConditional",
"ExpLandmarksConditional",
"LandmarksConditionalTime",
"LandmarksConditionalCholesky",
"ExpLandmarksConditionalCholesky",
"LandmarksConditionalCholeskyTime",
]
14 changes: 8 additions & 6 deletions mellon/_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
minimize_adam,
minimize_lbfgsb,
compute_log_density_x,
compute_conditional_mean,
compute_conditional_mean_times,
compute_conditional_mean_explog,
compute_parameter_cov_factor,
compute_conditional,
compute_conditional_times,
compute_conditional_explog,
generate_gaussian_sample,
calculate_gaussian_logpdf,
calculate_elbo,
Expand All @@ -24,9 +25,10 @@
"minimize_adam",
"minimize_lbfgsb",
"compute_log_density_x",
"compute_conditional_mean",
"compute_conditional_mean_times",
"compute_conditional_mean_explog",
"compute_parameter_cov_factor",
"compute_conditional",
"compute_conditional_times",
"compute_conditional_explog",
"generate_gaussian_sample",
"calculate_gaussian_logpdf",
"calculate_elbo",
Expand Down
Loading

0 comments on commit 64af1ce

Please sign in to comment.