Skip to content

Commit

Permalink
Predictor tab in sphinx doc
Browse files Browse the repository at this point in the history
  • Loading branch information
katosh committed Aug 30, 2023
1 parent fef82a9 commit 7107cd0
Show file tree
Hide file tree
Showing 6 changed files with 163 additions and 46 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
:caption: Modules:

Model <model>
Predictor <predictor>
Serialization <serialization>
Utilities <util>
Covariance Functions <cov>
Expand Down
114 changes: 114 additions & 0 deletions docs/source/predictor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
Predictors
==========

Predictors in the Mellon framework can be invoked directly via their `__call__`
method to produce function estimates at new locations. These predictors can
also double as Gaussian Processes, offering uncertainty estimattion options.
It also comes with serialization capabilities detailed in :ref:`serialization <serialization>`.

Basic Usage
-----------

To generate estimates for new, out-of-sample locations, instantiate a
predictor and call it like a function:

.. code-block:: python
:caption: Example of accessing the :class:`mellon.Predictor` from the
:class:`mellon.model.DensityEstimator` in Mellon Framework
:name: example-usage-density-predictor
model = mellon.model.DensityEstimator(...) # Initialize the model with appropriate arguments
model.fit(X) # Fit the model to the data
predictor = model.predict # Obtain the predictor object
predicted_values = predictor(Xnew) # Generate predictions for new locations
Uncertainy
------------

If the predictor was generated with
uncertainty estimates (typically by passing `predictor_with_uncertainty=True`
and `optimizer="advi"` to the model class, e.g., :class:`mellon.model.DensityEstimator`)
then it provides methods for computing variance at these locations, and co-variance to any other
location.

- Variance Methods:
- :meth:`mellon.Predictor.covariance`
- :meth:`mellon.Predictor.mean_covariance`
- :meth:`mellon.Predictor.uncertainty`

Classifications
---------------

The `Predictor` module contains various specialized subclasses of
:class:`mellon.Predictor`. The subclass returned by the model depends on:

- The `gp_type` argument, which specifies the Gaussian Process type.
- The nature of the predicted value—be it real-valued, positive, or time-sensitive.

While the `gp_type` impacts the internat iplementation the nature of the
prediction impacts its functionality leading to these subclasses

- **Real-valued**: :class:`mellon.Predictor`
- **Positive-valued**: :class:`mellon.base_predictor.ExpPredictor`
- **Time-sensitive**: :class:`mellon.base_predictor.PredictorTime`

Classifications of Predictors
-----------------------------

The `Predictor` module in the Mellon framework features a variety of
specialized subclasses of :class:`mellon.Predictor`. The specific subclass
instantiated by the model is contingent upon two key parameters:

- `gp_type`: This argument determines the type of Gaussian Process used internally.
- The nature of the predicted output: This can be real-valued, strictly positive, or time-sensitive.

The `gp_type` argument mainly affects the internal mathematical operations,
whereas the nature of the predicted value dictates the subclass's functional
capabilities:

- **Real-valued Predictions**: Such as log-density estimates, :class:`mellon.Predictor`.
- **Positive-valued Predictions**: Such as dimensionality estimates, :class:`mellon.base_predictor.ExpPredictor`.
- **Time-sensitive Predictions**: Such as time-sensitive density estimates :class:`mellon.base_predictor.PredictorTime`.



Vanilla Predictor
-----------------

Utilized in the following methods:

- :meth:`mellon.model.DensityEstimator.predict`
- :meth:`mellon.model.DimensionalityEstimator.predict_density`
- :meth:`mellon.model.FunctionEstimator.predict`

.. autoclass:: mellon.Predictor
:members:
:undoc-members:
:show-inheritance:
:exclude-members: n_obs, n_input_features

Exponential Predictor
---------------------

- Used in :meth:`mellon.model.DimensionalityEstimator.predict`
- Predicted values are strictly positive. Variance is expressed in log scale.

.. autoclass:: mellon.base_predictor.ExpPredictor
:members:
:undoc-members:
:show-inheritance:
:exclude-members: n_obs, n_input_features

Time-sensitive Predictor
------------------------

- Utilized in :meth:`mellon.model.TimeSensitiveDensityEstimator.predict`
- Special arguments `time` and `multi_time` permit time-specific predictions.

.. autoclass:: mellon.base_predictor.PredictorTime
:members:
:undoc-members:
:show-inheritance:
:exclude-members: n_obs, n_input_features

39 changes: 24 additions & 15 deletions docs/source/serialization.rst
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
.. _serialization:

Serialization
=============

The Mellon module provides a comprehensive suite of tools for serializing and deserializing estimators. This functionality is particularly useful in computational biology where models might be trained on one dataset and then used for making predictions on new datasets. The ability to serialize an estimator allows you to save the state of a model after it has been trained, and then load it later for predictions without needing to retrain the model.

For instance, you might have a large dataset on which you train an estimator. Training might take a long time due to the size and complexity of the data. With serialization, you can save the trained estimator and then easily share it with collaborators, apply it to new data, or use it in a follow-up study.

When an estimator is serialized, it includes a variety of metadata, including the serialization date, Python version, the class name of the estimator, and the Mellon version.
This metadata serves as a detailed record of the state of your environment at the time of serialization, which can be crucial for reproducibility in scientific research and for understanding the conditions under which the estimator was originally trained.
The Mellon module facilitates the serialization and deserialization of
predictors, streamlining model transfer and reuse. This is especially relevant
in computational biology for applying pre-trained models to new datasets.
Serialization captures essential metadata like date, Python version, estimator
class, and Mellon version, ensuring research reproducibility and context for
the original training.

All estimators in Mellon that inherit from the :class:`BaseEstimator` class, including :class:`mellon.model.DensityEstimator`, :class:`mellon.model.FunctionEstimator`, and :class:`mellon.model.DimensionalityEstimator`, have serialization capabilities.
After fitting data, all Mellon models generate a predictor via their `.predict`
property, including model classes like :class:`mellon.model.DensityEstimator`,
:class:`mellon.model.TimeSensitiveDensityEstimator`,
:class:`mellon.model.FunctionEstimator`, and
:class:`mellon.model.DimensionalityEstimator`.


Predictor Class
---------------

The `Predictor` class, accessible through the `predict` property of an estimator, handles the serialization and deserialization process.

.. autoclass:: mellon.Predictor
:members:
:undoc-members:
:show-inheritance:
:exclude-members: n_obs, n_input_features
All predictors inherit their serialization methods from :class:`mellon.Predictor`.

Serialization to AnnData
------------------------

Estimators can be serialized to an `AnnData`_ object. The `log_density` computation for the AnnData object shown below is a simplified example. For a more comprehensive guide on how to properly preprocess data to compute the log density, refer to our
Predictors can be serialized to an `AnnData`_ object. The `log_density`
computation for the AnnData object shown below is a simplified example. For a
more comprehensive guide on how to properly preprocess data to compute the log
density, refer to our
`basic tutorial notebook <https://github.com/settylab/Mellon/blob/main/notebooks/basic_tutorial.ipynb>`_.


Expand All @@ -49,6 +52,9 @@ Estimators can be serialized to an `AnnData`_ object. The `log_density` computat
Deserialization from AnnData
----------------------------

The function `mellon.Predictor.from_dict` can deserialize the
:class:`mellon.Predictor` and any sub class.

.. code-block:: python
# Load the AnnData object
Expand All @@ -65,6 +71,9 @@ Serialization to File

Mellon supports serialization to a human-readable JSON file and compressed file formats such as .gz (gzip) and .bz2 (bzip2).

The function `mellon.Predictor.from_json` can deserialize the
:class:`mellon.Predictor` and any sub class.

.. code-block:: python
# Serialization to JSON
Expand Down
34 changes: 14 additions & 20 deletions mellon/dimensionality_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,18 +509,16 @@ def fit(self, x=None, build_predict=True):
@property
def predict_density(self):
"""
Predicts the log density with an adaptive unit for each data point in `x`.
The unit of density depends on the dimensionality of the data.
A property that returns an instance of the :class:`mellon.Predictor` class. This predictor can
be used to predict the log density for new data points by calling the instance like a function.
Parameters
----------
x : array-like of shape (n_samples, n_features)
The new data for which to predict the log density.
The predictor instance also supports serialization features, which allow for saving and loading
the predictor's state. For more details, refer to the :class:`mellon.Predictor` documentation.
Returns
-------
array-like
The predicted log density for each test point in `x`.
mellon.Predictor
A predictor instance that computes the log density at each new data point.
Example
-------
Expand All @@ -535,22 +533,18 @@ def predict_density(self):
@property
def predict(self):
"""
Returns an instance of the :class:`mellon.Predictor` class, which predicts the dimensionality
at each point in `x`.
This instance includes a __call__ method, which can be used to predict the dimensionality.
The instance also supports serialization features, allowing for saving and loading the predictor's
state. For more details, refer to :class:`mellon.Predictor`.
A property that returns an instance of the :class:`mellon.base_predictor.ExpPredictor` class.
This predictor can be used to predict the dimensionality for new data points by calling
the instance like a function.
Parameters
----------
x : array-like of shape (n_samples, n_features)
The new data for which to predict the dimensionality.
The predictor instance also supports serialization features, which allow for saving and loading
the predictor's state. For more details, refer to the :class:`mellon.base_predictor.ExpPredictor`
documentation.
Returns
-------
array-like
The predicted dimensionality for each test point in `x`.
mellon.base_predictor.ExpPredictor
A predictor instance that computes the dimensionality at each new data point.
Example
-------
Expand Down
17 changes: 6 additions & 11 deletions mellon/time_sensitive_density_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,23 +652,18 @@ def fit(self, x=None, times=None, build_predict=True):
@property
def predict(self):
R"""
An instance of the :class:`mellon.Predictor` that predicts the log density at each point in x.
An instance of the :class:`mellon.base_predictor.PredictorTime` that predicts the
log density at each point in x and time point time.
The instance contains a __call__ method which can be used to predict the log density.
This instance also supports serialization features which allows for saving
and loading the predictor state. Refer to mellon.Predictor documentation for more details.
Note that the last column of the input array `x` should contain the time information.
Parameters
----------
x : array-like
The new data to predict, where the last column should contain the time information.
and loading the predictor state. Refer to :class:`mellon.base_predictor.PredictorTime`
documentation for more details.
Returns
-------
log_density : array-like
The log density at each test point in `x`.
mellon.base_predictor.PredictorTime
A predictor instance that computes the log density at each new data point.
Example
-------
Expand Down
4 changes: 4 additions & 0 deletions tests/dimensionality_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def test_dimensionality_estimator_serialization_with_uncertainty(
est.fit(X)
dens_appr = est.predict(X)
log_dens_appr = est.predict(X, logscale=True)
assert (
is_close
), "The exp of the log scale prediction should mix the original prediction."
is_close = jnp.all(jnp.isclose(dens_appr, jnp.exp(log_dens_appr)))
covariance = est.predict.covariance(X)
assert covariance.shape == (
n,
Expand Down

0 comments on commit 7107cd0

Please sign in to comment.