diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 69897c3..23a5340 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -7,7 +7,7 @@ build:
sphinx:
configuration: docs/conf.py
# disable this for more lenient docs builds
- fail_on_warning: true
+ fail_on_warning: false
python:
install:
- method: pip
diff --git a/docs/api.md b/docs/api.md
index 0484c14..2f2b388 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -1,38 +1,25 @@
# API
-## Preprocessing
+## Bayesian Models
```{eval-rst}
-.. module:: torchgmm.pp
+.. module:: torchgmm.bayes
.. currentmodule:: torchgmm
.. autosummary::
:toctree: generated
- pp.basic_preproc
+ bayes.GaussianMixture
```
-## Tools
+## Clustering Models
```{eval-rst}
-.. module:: torchgmm.tl
+.. module:: torchgmm.clustering
.. currentmodule:: torchgmm
.. autosummary::
:toctree: generated
- tl.basic_tool
-```
-
-## Plotting
-
-```{eval-rst}
-.. module:: torchgmm.pl
-.. currentmodule:: torchgmm
-
-.. autosummary::
- :toctree: generated
-
- pl.basic_plot
- pl.BasicClass
+ clustering.KMeans
```
diff --git a/docs/benchmark.md b/docs/benchmark.md
new file mode 100644
index 0000000..c522b6f
--- /dev/null
+++ b/docs/benchmark.md
@@ -0,0 +1,90 @@
+# Benchmarks
+
+These benchmarks are based on PyCave's benchmark suite, which evaluates the runtime performance of TorchGMM against the implementation found in scikit-learn through an exhaustive set of experiments at varying dataset sizes.
+
+All benchmarks are run on an instance with an Intel Xeon E5-2630 v4 CPU (2.2 GHz), using at most 4
+cores and 60 GiB of memory. In addition, a single GeForce GTX 1080 Ti GPU (11 GiB memory) is
+available. Each benchmark is run at least 5 times to obtain the reported performance measures.
+
+## Gaussian Mixture
+
+### Setup
+
+To measure the performance of fitting a Gaussian mixture model, the number of iterations after
+initialization is fixed to 100 so that variance introduced by the convergence criterion is not
+measured. Further, initialization uses the known means that were used to generate the data, to
+avoid issues with degenerate covariance matrices. Thus, all benchmarks essentially measure the
+performance after K-means initialization has been run. Benchmarks for K-means itself are listed below.
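+
+For illustration, a single measurement under this protocol might look roughly as follows. This is a
+minimal sketch: the actual benchmark harness is not part of this repository, the data is a random
+placeholder, and the TorchGMM constructor arguments (`init_means`, `trainer_params`) are assumptions
+based on the PyCave-style interface.
+
+```python
+import time
+
+import numpy as np
+import torch
+from sklearn.mixture import GaussianMixture as SklearnGMM
+
+from torchgmm.bayes import GaussianMixture
+
+data = np.random.randn(100_000, 32).astype(np.float32)  # placeholder dataset
+means = data[np.random.choice(len(data), 16, replace=False)]  # "known" means used for initialization
+
+start = time.perf_counter()
+SklearnGMM(n_components=16, covariance_type="diag", means_init=means, max_iter=100, tol=0).fit(data)
+print("scikit-learn:", time.perf_counter() - start)
+
+start = time.perf_counter()
+GaussianMixture(16, covariance_type="diag", init_means=torch.from_numpy(means), trainer_params={"max_epochs": 100}).fit(data)
+print("torchgmm:", time.perf_counter() - start)
+```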
+
+### Results
+
+| | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
+| ------------------ | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
+| `[10k, 8] -> 4` | **352 ms** | 649 ms | 3.9 s | 358 ms | 3.6 s |
+| `[100k, 32] -> 16` | 18.4 s | 4.3 s | 10.0 s | **527 ms** | 3.9 s |
+| `[1M, 64] -> 64` | 730 s | 196 s | 284 s | **7.7 s** | 15.3 s |
+
+Training Duration for Diagonal Covariance (`[num_datapoints, num_features] -> num_components`)
+
+| | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
+| ------------------ | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
+| `[10k, 8] -> 4` | 699 ms | 570 ms | 3.6 s | **356 ms** | 3.3 s |
+| `[100k, 32] -> 16` | 72.2 s | 12.1 s | 16.1 s | **919 ms** | 3.8 s |
+| `[1M, 64] -> 64` | -- | -- | -- | -- | **63.4 s** |
+
+Training Duration for Tied Covariance (`[num_datapoints, num_features] -> num_components`)
+
+| | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
+| ------------------ | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
+| `[10k, 8] -> 4` | 1.1 s | 679 ms | 4.1 s | **648 ms** | 4.4 s |
+| `[100k, 32] -> 16` | 110 s | 13.5 s | 21.2 s | **2.4 s** | 7.8 s |
+
+Training Duration for Full Covariance (`[num_datapoints, num_features] -> num_components`)
+
+### Summary
+
+TorchGMM's implementation of the Gaussian mixture model is markedly more efficient than the one found
+in scikit-learn. Even on the CPU, TorchGMM significantly outperforms scikit-learn already at 100k
+datapoints. When moving to the GPU, TorchGMM unfolds its full potential and yields speedups of around
+100x. For larger datasets, mini-batch training is the only alternative. TorchGMM fully supports this,
+although training takes approximately twice as long as training on the full data: the M-step of the
+EM algorithm needs to be split across epochs, which, in turn, requires replaying the E-step.
+
+## K-Means
+
+### Setup
+
+For the scikit-learn implementation, Lloyd's algorithm is used instead of Elkan's algorithm to allow
+for a meaningful comparison with TorchGMM (which implements Lloyd's algorithm).
+
+Further, the number of iterations after initialization is fixed to 100 so that variance introduced by
+the convergence criterion is not measured.
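+
+The corresponding estimator construction might look as follows (a minimal sketch; the TorchGMM
+`KMeans` arguments such as `init_strategy` and `trainer_params` are assumptions based on the
+PyCave-style interface, and `algorithm="lloyd"` is spelled `"full"` on older scikit-learn versions):
+
+```python
+from sklearn.cluster import KMeans as SklearnKMeans
+
+from torchgmm.clustering import KMeans
+
+# scikit-learn: explicitly select Lloyd's algorithm and fix the iteration count
+sk_kmeans = SklearnKMeans(n_clusters=16, algorithm="lloyd", init="random", n_init=1, max_iter=100)
+
+# TorchGMM always uses Lloyd's algorithm; iterations are capped via the trainer
+tg_kmeans = KMeans(16, init_strategy="random", trainer_params={"max_epochs": 100})
+```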
+
+### Results
+
+| | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
+| ------------------- | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
+| `[10k, 8] -> 4` | **13 ms** | 412 ms | 797 ms | 387 ms | 2.1 s |
+| `[100k, 32] -> 16` | **311 ms** | 2.1 s | 3.4 s | 707 ms | 2.5 s |
+| `[1M, 64] -> 64` | 10.0 s | 73.6 s | 58.1 s | **8.2 s** | 10.0 s |
+| `[10M, 128] -> 128` | 254 s | -- | -- | -- | **133 s** |
+
+Training Duration for Random Initialization (`[num_datapoints, num_features] -> num_clusters`)
+
+| | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
+| ------------------- | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
+| `[10k, 8] -> 4` | **15 ms** | 170 ms | 930 ms | 431 ms | 2.4 s |
+| `[100k, 32] -> 16` | **542 ms** | 2.3 s | 4.3 s | 840 ms | 3.2 s |
+| `[1M, 64] -> 64` | 25.3 s | 93.4 s | 83.7 s | **13.1 s** | 17.1 s |
+| `[10M, 128] -> 128` | 827 s | -- | -- | -- | **369 s** |
+
+Training Duration for K-Means++ Initialization (`[num_datapoints, num_features] -> num_clusters`)
+
+### Summary
+
+As it turns out, it is hard to outperform the K-means implementation found in scikit-learn. For
+small datasets, the overhead of PyTorch and PyTorch Lightning renders TorchGMM comparatively slow.
+As more data becomes available, however, TorchGMM becomes relatively faster and, when leveraging the
+GPU, it finally outperforms scikit-learn at a dataset size of 1M datapoints. Nonetheless, the
+improvement is marginal.
diff --git a/docs/contributing.md b/docs/contributing.md
index c5bb81a..1d819aa 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -103,7 +103,6 @@ Please write documentation for new or changed features and use-cases. This proje
- [Numpy-style docstrings][numpydoc] (through the [napoloen][numpydoc-napoleon] extension).
- Jupyter notebooks as tutorials through [myst-nb][] (See [Tutorials with myst-nb](#tutorials-with-myst-nb-and-jupyter-notebooks))
- [Sphinx autodoc typehints][], to automatically reference annotated input and output types
-- Citations (like {cite:p}`Virshup_2023`) can be included with [sphinxcontrib-bibtex](https://sphinxcontrib-bibtex.readthedocs.io/)
See the [scanpy developer docs](https://scanpy.readthedocs.io/en/latest/dev/documentation.html) for more information
on how to write documentation.
diff --git a/docs/index.md b/docs/index.md
index 8b5f298..a4653ad 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -9,7 +9,5 @@
api.md
changelog.md
contributing.md
-references.md
-
-notebooks/example
+benchmark.md
```
diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb
deleted file mode 100644
index efb1685..0000000
--- a/docs/notebooks/example.ipynb
+++ /dev/null
@@ -1,171 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Example notebook"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "from anndata import AnnData\n",
- "import pandas as pd\n",
- "import torchgmm"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "adata = AnnData(np.random.normal(size=(20, 10)))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "With myst it is possible to link in the text cell of a notebook such as this one the documentation of a function or a class.\n",
- "\n",
- "Let's take as an example the function {func}`torchgmm.pp.basic_preproc`. \n",
- "You can see that by clicking on the text, the link redirects to the API documentation of the function. \n",
- "Check the raw markdown of this cell to understand how this is specified.\n",
- "\n",
- "This works also for any package listed by `intersphinx`. Go to `docs/conf.py` and look for the `intersphinx_mapping` variable. \n",
- "There, you will see a list of packages (that this package is dependent on) for which this functionality is supported. \n",
- "\n",
- "For instance, we can link to the class {class}`anndata.AnnData`, to the attribute {attr}`anndata.AnnData.obs` or the method {meth}`anndata.AnnData.write`.\n",
- "\n",
- "Again, check the raw markdown of this cell to see how each of these links are specified.\n",
- "\n",
- "You can read more about this in the [intersphinx page](https://www.sphinx-doc.org/en/master/usage/extensions/intersphinx.html) and the [myst page](https://myst-parser.readthedocs.io/en/v0.15.1/syntax/syntax.html#roles-an-in-line-extension-point)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Implement a preprocessing function here."
- ]
- },
- {
- "data": {
- "text/plain": [
- "0"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "torchgmm.pp.basic_preproc(adata)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
-       "text/html": [
-        "(HTML representation of the DataFrame omitted)"
-       ],
- "text/plain": [
- " A B\n",
- "0 a 1\n",
- "1 b 2\n",
- "2 c 3"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "pd.DataFrame().assign(A=[\"a\", \"b\", \"c\"], B=[1, 2, 3])"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3.9.12 ('squidpy39')",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.3"
- },
- "vscode": {
- "interpreter": {
- "hash": "ae6466e8d4f517858789b5c9e8f0ed238fb8964458a36305fca7bddc149e9c64"
- }
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/docs/references.bib b/docs/references.bib
index 9f5bed4..e69de29 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -1,10 +0,0 @@
-@article{Virshup_2023,
- doi = {10.1038/s41587-023-01733-8},
- url = {https://doi.org/10.1038%2Fs41587-023-01733-8},
- year = 2023,
- month = {apr},
- publisher = {Springer Science and Business Media {LLC}},
- author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis and},
- title = {The scverse project provides a computational ecosystem for single-cell omics data analysis},
- journal = {Nature Biotechnology}
-}
diff --git a/docs/references.md b/docs/references.md
deleted file mode 100644
index 00ad6a6..0000000
--- a/docs/references.md
+++ /dev/null
@@ -1,5 +0,0 @@
-# References
-
-```{bibliography}
-:cited:
-```
diff --git a/pyproject.toml b/pyproject.toml
index b651b52..e202a38 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,10 +4,10 @@ requires = ["hatchling"]
[project]
name = "torchgmm"
-version = "0.0.1"
+version = "0.1.1"
description = "Run Gaussian Mixture Models on single or multiple CPUs/GPUs"
readme = "README.md"
-requires-python = ">=3.10"
+requires-python = ">=3.8"
license = {file = "LICENSE"}
authors = [
{name = "Marco Varrone"},
@@ -22,6 +22,10 @@ dependencies = [
"anndata",
# for debug logging (referenced from the issue template)
"session-info",
+ "numpy>=1.20.3,<2.0.0",
+ "pytorch-lightning>=1.6.0",
+ "torch>1.11.0",
+ "torchmetrics>=0.6"
]
[project.optional-dependencies]
@@ -47,6 +51,9 @@ doc = [
test = [
"pytest",
"coverage",
+ "flaky",
+ "pytest-benchmark",
+ "scikit-learn"
]
[tool.coverage.run]
@@ -110,6 +117,10 @@ ignore = [
"D203",
# We want docstrings to start immediately after the opening triple quote
"D213",
+ "D102",
+ "D205",
+ "D200",
+ "B024"
]
[tool.ruff.lint.pydocstyle]
@@ -127,7 +138,5 @@ skip = [
"src/**/basic.py",
"docs/api.md",
"docs/changelog.md",
- "docs/references.bib",
- "docs/references.md",
- "docs/notebooks/example.ipynb",
+ "docs/benchmark.md",
]
diff --git a/src/torchgmm/__init__.py b/src/torchgmm/__init__.py
index d2e0979..edd7287 100644
--- a/src/torchgmm/__init__.py
+++ b/src/torchgmm/__init__.py
@@ -1,7 +1,34 @@
-from importlib.metadata import version
+import logging
+import warnings
-from . import pl, pp, tl
+import torchgmm.base
-__all__ = ["pl", "pp", "tl"]
+# This is taken from PyTorch Lightning and ensures that logging for this package is enabled
+_root_logger = logging.getLogger()
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
+if not _root_logger.hasHandlers():
+ _logger.addHandler(logging.StreamHandler())
+ _logger.propagate = False
-__version__ = version("torchgmm")
+# This disables most logs generated by PyTorch Lightning
+logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)
+warnings.filterwarnings(action="ignore", message=".*Consider increasing the value of the `num_workers` argument.*")
+warnings.filterwarnings(action="ignore", message=".*`LightningModule.configure_optimizers` returned `None`.*")
+warnings.filterwarnings(action="ignore", message=".*`LoggerConnector.gpus_metrics` was deprecated in v1.5.*")
+
+
+# We also want to define a function which silences info logs
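+# (e.g. `torchgmm.set_logging_level(logging.WARNING)` silences info-level output such as progress bars)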
+def set_logging_level(level: int) -> None:
+ """
+    Sets the logging level for the entire package. By default, logging is enabled at the
+    ``INFO`` level.
+
+    Args:
+        level: The logging level to set, e.g. ``logging.WARNING`` to silence info-level output.
+ """
+ _logger.setLevel(level)
+ torchgmm.base.set_logging_level(level)
+
+
+# Export
+__all__ = ["set_logging_level"]
diff --git a/src/torchgmm/base/__init__.py b/src/torchgmm/base/__init__.py
new file mode 100644
index 0000000..ed9e6c1
--- /dev/null
+++ b/src/torchgmm/base/__init__.py
@@ -0,0 +1,24 @@
+import logging
+
+from .estimator import BaseEstimator, ConfigurableBaseEstimator
+
+# This is taken from PyTorch Lightning and ensures that logging for this package is enabled
+_root_logger = logging.getLogger()
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
+if not _root_logger.hasHandlers():
+ _logger.addHandler(logging.StreamHandler())
+ _logger.propagate = False
+
+
+def set_logging_level(level: int) -> None:
+ """
+    Sets the logging level for this package. By default, logging is enabled at the ``INFO``
+    level.
+
+    Args:
+        level: The logging level to set, e.g. ``logging.WARNING`` to silence info-level output.
+ """
+ _logger.setLevel(level)
+
+
+__all__ = ["BaseEstimator", "ConfigurableBaseEstimator"]
diff --git a/src/torchgmm/base/data/__init__.py b/src/torchgmm/base/data/__init__.py
new file mode 100644
index 0000000..75d6b71
--- /dev/null
+++ b/src/torchgmm/base/data/__init__.py
@@ -0,0 +1,14 @@
+from .collation import collate_tensor, collate_tuple
+from .loader import DataLoader
+from .sampler import RangeBatchSampler
+from .types import DataLoaderLike, TensorLike, dataset_from_tensors
+
+__all__ = [
+ "DataLoader",
+ "DataLoaderLike",
+ "RangeBatchSampler",
+ "TensorLike",
+ "collate_tensor",
+ "collate_tuple",
+ "dataset_from_tensors",
+]
diff --git a/src/torchgmm/base/data/collation.py b/src/torchgmm/base/data/collation.py
new file mode 100644
index 0000000..395856a
--- /dev/null
+++ b/src/torchgmm/base/data/collation.py
@@ -0,0 +1,23 @@
+from typing import Tuple
+
+import torch
+
+
+def collate_tuple(batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
+ """
+ Collate a tuple of batch items by returning the input tuple.
+
+ This is the default used by
+ :class:`~torchgmm.base.data.DataLoader` when slices are cut from the underlying data source.
+ """
+ return batch
+
+
+def collate_tensor(batch: Tuple[torch.Tensor, ...]) -> torch.Tensor:
+ """
+    Collates a tuple of batch items by returning only the first tensor.
+
+    This might be useful if only a single tensor is passed to
+    :class:`~torchgmm.base.data.DataLoader`.
+ """
+ return batch[0]
diff --git a/src/torchgmm/base/data/loader.py b/src/torchgmm/base/data/loader.py
new file mode 100644
index 0000000..3a5f88d
--- /dev/null
+++ b/src/torchgmm/base/data/loader.py
@@ -0,0 +1,83 @@
+from typing import Any, Iterator, TypeVar
+
+try:
+ from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper as IndexBatchSamplerWrapper
+except ImportError:
+ from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
+from torch.utils.data import DataLoader as TorchDataLoader
+from torch.utils.data import Dataset, TensorDataset
+from torch.utils.data.sampler import SequentialSampler
+
+from .collation import collate_tuple
+from .sampler import RangeBatchSampler
+
+T_co = TypeVar("T_co", covariant=True)
+
+
+class DataLoader(TorchDataLoader[T_co]):
+    """Extension of PyTorch's built-in data loader. This implementation retrieves contiguous
+    indices from a :class:`~torch.utils.data.TensorDataset` orders of magnitude faster, which
+    enables traditional machine learning methods to run at a speed similar to the implementations
+    found in Scikit-learn.
+
+ Note:
+ Retrieving contiguous indices is only possible when all of the following conditions apply:
+
+ - ``shuffle=False`` or ``batch_sampler`` is of type
+ :class:`~torchgmm.base.data.RangeBatchSampler`
+ - ``sampler is None``
+ - ``num_workers=0``
+ - ``dataset`` is not iterable
+ """
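+
+    Example:
+        A minimal sketch for a single-tensor dataset (shapes are illustrative)::
+
+            dataset = TensorDataset(torch.randn(10_000, 8))
+            loader = DataLoader(dataset, batch_size=256)
+            for (batch,) in loader:
+                ...  # each batch is a contiguous slice of the underlying tensor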
+
+ def __init__(self, dataset: Dataset[T_co], **kwargs: Any):
+ """
+ Args:
+ dataset: The dataset from which to load the data.
+ kwargs: Keyword arguments passed to :meth:`torch.utils.data.DataLoader.__init__`.
+ """
+ if (
+ not kwargs.get("shuffle", False)
+ and "sampler" not in kwargs
+ and "batch_sampler" not in kwargs
+ and kwargs.get("num_workers", 0) == 0
+ and isinstance(dataset, TensorDataset)
+ ):
+ kwargs["batch_sampler"] = RangeBatchSampler(
+ SequentialSampler(dataset),
+ batch_size=kwargs.get("batch_size", 1),
+ drop_last=kwargs.get("drop_last", False),
+ )
+ kwargs.pop("batch_size", None)
+ kwargs.pop("shuffle", None)
+ kwargs.pop("drop_last", None)
+ kwargs.setdefault("collate_fn", collate_tuple)
+
+ super().__init__(dataset, **kwargs) # type: ignore
+
+ self.custom_batching = self.num_workers == 0 and (
+ isinstance(self.batch_sampler, RangeBatchSampler)
+ or (
+ self.batch_sampler is not None
+ and hasattr(self.batch_sampler, "sampler")
+ and isinstance(self.batch_sampler.sampler, RangeBatchSampler)
+ )
+ or (
+ isinstance(self.batch_sampler, IndexBatchSamplerWrapper)
+ and isinstance(self.batch_sampler._batch_sampler, RangeBatchSampler) # type: ignore
+ )
+ )
+
+ def __iter__(self) -> Iterator[Any]: # pylint: disable=inconsistent-return-statements
+ if not self.custom_batching:
+ yield from super().__iter__()
+ return
+
+ for indices in self.batch_sampler:
+ if isinstance(indices, range):
+ subscript = slice(indices.start, indices.stop)
+ yield self.collate_fn(tuple(t[subscript] for t in self.dataset.tensors))
+ else:
+ yield self.collate_fn(tuple(t[indices] for t in self.dataset.tensors))
diff --git a/src/torchgmm/base/data/sampler.py b/src/torchgmm/base/data/sampler.py
new file mode 100644
index 0000000..7040326
--- /dev/null
+++ b/src/torchgmm/base/data/sampler.py
@@ -0,0 +1,40 @@
+import math
+from typing import Iterator
+
+from torch.utils.data import Sampler
+from torch.utils.data.sampler import SequentialSampler
+
+
+class RangeBatchSampler(Sampler[range]):
+ """
+ Sampler providing batches of contiguous indices.
+
+ This sampler can be used with
+ :class:`torchgmm.base.data.DataLoader` to provide significant speedups for tensor datasets.
+ """
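+
+    Example:
+        A minimal sketch (sizes are illustrative)::
+
+            sampler = RangeBatchSampler(SequentialSampler(dataset), batch_size=256)
+            list(sampler)  # [range(0, 256), range(256, 512), ...]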
+
+ def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool = False):
+ """
+ Args:
+ sampler: The sampler providing indices. Must be a sequential sampler. Note that the
+ only purpose of this sampler is to determine its length.
+ batch_size: The number of items to sample for each batch.
+            drop_last: Whether to drop the last batch if the number of items in the dataset is
+                not divisible by ``batch_size``.
+ """
+ assert isinstance(sampler, SequentialSampler), f"{self.__class__.__name__} only works with sequential samplers."
+
+ super().__init__(None)
+ self.dataset_size = len(sampler)
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ def __len__(self) -> int:
+ if self.drop_last:
+ return self.dataset_size // self.batch_size
+ return math.ceil(self.dataset_size / self.batch_size)
+
+    def __iter__(self) -> Iterator[range]:
+        for i in range(len(self)):
+            # Clamp the final batch so that indices never exceed the dataset size
+            yield range(i * self.batch_size, min((i + 1) * self.batch_size, self.dataset_size))
diff --git a/src/torchgmm/base/data/types.py b/src/torchgmm/base/data/types.py
new file mode 100644
index 0000000..eda3709
--- /dev/null
+++ b/src/torchgmm/base/data/types.py
@@ -0,0 +1,48 @@
+from typing import Sequence, TypeVar, Union
+
+import numpy as np
+import numpy.typing as npt
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+
+T = TypeVar("T")
+
+TensorLike = Union[
+ npt.NDArray[np.float32],
+ torch.Tensor,
+]
+TensorLike.__doc__ = """
+Type annotation for functions accepting any kind of tensor data as input. Consider using this
+annotation if your methods in an estimator derived from :class:`torchgmm.base.BaseEstimator` work on
+tensors.
+"""
+
+DataLoaderLike = Union[
+ DataLoader[T],
+ Sequence[DataLoader[T]],
+]
+DataLoaderLike.__doc__ = """
+Generic type annotation for functions accepting any data loader as input. Consider using this
+annotation for the implementation of methods in an estimator derived from
+:class:`torchgmm.base.BaseEstimator`.
+"""
+
+
+def dataset_from_tensors(*data: TensorLike) -> TensorDataset:
+ """
+    Transforms a set of tensor-like items into a dataset.
+
+ Args:
+ data: The tensor-like items.
+
+ Returns
+ -------
+ The dataset.
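+
+    Example:
+        ``dataset_from_tensors(np.zeros((10, 3)), torch.ones(10))`` returns a
+        :class:`~torch.utils.data.TensorDataset` holding both tensors, with NumPy arrays
+        converted via :meth:`torch.from_numpy`.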
+ """
+ return TensorDataset(*[_to_tensor(t) for t in data])
+
+
+def _to_tensor(data: TensorLike) -> torch.Tensor:
+ if isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ return data
diff --git a/src/torchgmm/base/estimator/__init__.py b/src/torchgmm/base/estimator/__init__.py
new file mode 100644
index 0000000..9294f65
--- /dev/null
+++ b/src/torchgmm/base/estimator/__init__.py
@@ -0,0 +1,12 @@
+from .base import BaseEstimator
+from .configurable import ConfigurableBaseEstimator
+from .exception import NotFittedError
+from .mixins import PredictorMixin, TransformerMixin
+
+__all__ = [
+ "BaseEstimator",
+ "ConfigurableBaseEstimator",
+ "NotFittedError",
+ "PredictorMixin",
+ "TransformerMixin",
+]
diff --git a/src/torchgmm/base/estimator/_protocols.py b/src/torchgmm/base/estimator/_protocols.py
new file mode 100644
index 0000000..e9fd696
--- /dev/null
+++ b/src/torchgmm/base/estimator/_protocols.py
@@ -0,0 +1,19 @@
+# pylint: disable=missing-class-docstring,missing-function-docstring
+
+from typing import Protocol, TypeVar
+
+D_contra = TypeVar("D_contra", contravariant=True)
+R_co = TypeVar("R_co", covariant=True)
+E = TypeVar("E", bound="Estimator") # type: ignore
+
+
+class Estimator(Protocol[D_contra]):
+ def fit(self: E, data: D_contra) -> E: ...
+
+
+class Transformer(Estimator[D_contra], Protocol[D_contra, R_co]):
+ def transform(self, data: D_contra) -> R_co: ...
+
+
+class Predictor(Estimator[D_contra], Protocol[D_contra, R_co]):
+ def predict(self, data: D_contra) -> R_co: ...
diff --git a/src/torchgmm/base/estimator/base.py b/src/torchgmm/base/estimator/base.py
new file mode 100644
index 0000000..19be9fe
--- /dev/null
+++ b/src/torchgmm/base/estimator/base.py
@@ -0,0 +1,360 @@
+from __future__ import annotations
+
+import copy
+import inspect
+import json
+import logging
+import pickle
+import warnings
+from abc import ABC
+from pathlib import Path
+from typing import Any, TypeVar
+
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+
+from torchgmm.base.utils.path import PathType
+
+from .exception import NotFittedError
+
+E = TypeVar("E", bound="BaseEstimator") # type: ignore
+T = TypeVar("T")
+
+logger = logging.getLogger(__name__)
+
+
+class BaseEstimator(ABC):
+ """
+    Base estimator class that all estimators should inherit from. This base estimator does not
+    enforce the implementation of any methods, but users should follow the `Scikit-learn guide on
+    implementing estimators <https://scikit-learn.org/stable/developers/develop.html>`_. Some of
+    the methods mentioned in this guide are already implemented in this base estimator and work as
+    expected if the aspects listed below are followed.
+
+    In contrast to Scikit-learn's estimators, this estimator is strongly typed and integrates well
+ with PyTorch Lightning. Most importantly, it provides the :meth:`trainer` method which returns
+ a fully configured trainer to be used by other methods. The configuration is stored in the
+ estimator and can be adjusted by passing parameters to ``default_params``, ``user_params`` and
+ ``overwrite_params`` when calling ``super().__init__()``. By default, the base estimator sets
+ the following flags:
+
+ - Logging is disabled (``logger=False``).
+ - Logging is performed at every step (``log_every_n_steps=1``).
+ - The progress bar is only enabled (``enable_progress_bar``) if TorchGMM's logging level is
+ ``INFO`` or more verbose.
+ - Checkpointing is only enabled (``enable_checkpointing``) if TorchGMM's logging level is
+ ``DEBUG`` or more verbose.
+ - The model summary is only enabled (``enable_model_summary``) if TorchGMM's logging level is
+ ``DEBUG`` or more verbose.
+
+ Note that the logging level can be changed via :meth:`torchgmm.base.set_logging_level`.
+
+ When subclassing this base estimator, users should take care of the following aspects:
+
+ - All parameters passed to the initializer must be assigned to attributes with the same name.
+ This ensures that :meth:`get_params` and :meth:`set_params` work as expected. Parameters that
+ are passed to the trainer *must* be named ``trainer_params`` and should not be manually
+ assigned to an attribute (this is handled by the base estimator).
+ - Fitted attributes must (1) have a single trailing underscore (e.g. ``model_``) and (2) be
+ defined as annotations. This ensures that :meth:`save` and :meth:`load` properly manage the
+ estimator's persistence.
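+
+    Example:
+        A minimal sketch of a subclass that follows these rules (the estimator and its parameters
+        are hypothetical)::
+
+            class MyEstimator(BaseEstimator):
+                model_: object  # fitted attribute, declared as an annotation
+
+                def __init__(self, num_components: int = 2, trainer_params=None):
+                    super().__init__(user_params=trainer_params)
+                    self.num_components = num_components  # init params keep their names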
+ """
+
+ def __init__(
+ self,
+ *,
+ default_params: dict[str, Any] | None = None,
+ user_params: dict[str, Any] | None = None,
+ overwrite_params: dict[str, Any] | None = None,
+ ):
+ """
+ Args:
+ default_params: Estimator-specific parameters that provide defaults for configuring the
+ PyTorch Lightning trainer. An example might be setting ``max_epochs``. Overwrites
+ the default parameters established by the base estimator.
+ user_params: User-specific parameters that configure the PyTorch Lightning trainer.
+ This dictionary should be passed through from a ``trainer_params`` init argument in
+ subclasses. Overwrites any of the default parameters.
+ overwrite_params: PyTorch Lightning trainer flags that need to be ensured independently
+ of user-provided parameters. For example, ``max_epochs`` could be fixed to a
+ certain value.
+ """
+ self.trainer_params_user = user_params
+ self.trainer_params = {
+ **{
+ "logger": False,
+ "log_every_n_steps": 1,
+ "enable_progress_bar": logger.getEffectiveLevel() <= logging.INFO,
+ "enable_checkpointing": logger.getEffectiveLevel() <= logging.DEBUG,
+ "enable_model_summary": logger.getEffectiveLevel() <= logging.DEBUG,
+ },
+ **(default_params or {}),
+ **(user_params or {}),
+ **(overwrite_params or {}),
+ }
+
+ def trainer(self, **kwargs: Any) -> pl.Trainer:
+ """
+ Returns the trainer as configured by the estimator. Typically, this method is only called
+ by functions in the estimator.
+
+ Args:
+ kwargs: Additional arguments that override the trainer arguments registered in the
+ initializer of the estimator.
+
+ Returns
+ -------
+ A fully initialized PyTorch Lightning trainer.
+
+ Note:
+ This function should be preferred over initializing the trainer directly. It ensures
+ that the returned trainer correctly deals with TorchGMM components that may be
+ introduced in the future.
+ """
+ return pl.Trainer(**{**self.trainer_params, **kwargs})
+
+ # ---------------------------------------------------------------------------------------------
+ # PERSISTENCE
+
+ @property
+ def persistent_attributes(self) -> list[str]:
+ """
+ Returns the list of fitted attributes that ought to be saved and loaded.
+
+ By default, this encompasses all annotations.
+ """
+ return list(self.__annotations__.keys())
+
+ def save(self, path: PathType) -> None:
+        """Saves the estimator to the provided directory. It saves the estimator's parameters to a
+        file named ``params.json`` (or ``params.pickle`` as a fallback) and additional files for
+        the fitted attributes and model (if applicable). For more information on the files saved
+        for the fitted model or for more customization, look at :meth:`get_params` and
+        :meth:`torchgmm.base.nn.Configurable.save`.
+
+ Args:
+ path: The directory to which all files should be saved.
+
+ Note:
+ This method may be called regardless of whether the estimator has already been fitted.
+
+ Attention:
+ If the dictionary returned by :meth:`get_params` is not JSON-serializable, this method
+ uses :mod:`pickle` which is not necessarily backwards-compatible.
+ """
+ path = Path(path)
+ assert not path.exists() or path.is_dir(), "Estimators can only be saved to a directory."
+
+ path.mkdir(parents=True, exist_ok=True)
+ self.save_parameters(path)
+ try:
+ self.save_attributes(path)
+ except NotFittedError:
+ # In case attributes are not fitted, we just don't save them
+ pass
+
+ def save_parameters(self, path: Path) -> None:
+ """
+ Saves the parameters of this estimator. By default, it uses JSON and falls back to
+        :mod:`pickle`. If subclasses use non-primitive types as parameters, they should overwrite
+ this method.
+
+ Typically, this method should not be called directly. It is called as part of :meth:`save`.
+
+ Args:
+ path: The directory to which the parameters should be saved.
+ """
+ params = self.get_params()
+ try:
+ data = json.dumps(params, indent=4)
+ with (path / "params.json").open("w+") as f:
+ f.write(data)
+ except TypeError:
+ warnings.warn(
+ f"Failed to serialize parameters of `{self.__class__.__name__}` to JSON. " "Falling back to `pickle`.",
+ stacklevel=2,
+ )
+ with (path / "params.pickle").open("wb+") as f:
+ pickle.dump(params, f)
+
+ def save_attributes(self, path: Path) -> None:
+ """
+ Saves the fitted attributes of this estimator. By default, it uses JSON and falls back to
+ :mod:`pickle`. Subclasses should overwrite this method if non-primitive attributes are
+ fitted.
+
+ Typically, this method should not be called directly. It is called as part of :meth:`save`.
+
+ Args:
+            path: The directory to which the fitted attributes should be saved.
+
+ Raises
+ ------
+ NotFittedError: If the estimator has not been fitted.
+ """
+ if len(self.persistent_attributes) == 0:
+ return
+
+ attributes = {attribute: getattr(self, attribute) for attribute in self.persistent_attributes}
+ try:
+ data = json.dumps(attributes, indent=4)
+ with (path / "attributes.json").open("w+") as f:
+ f.write(data)
+ except TypeError:
+ warnings.warn(
+ f"Failed to serialize fitted attributes of `{self.__class__.__name__}` to JSON. "
+ "Falling back to `pickle`.",
+ stacklevel=2,
+ )
+ with (path / "attributes.pickle").open("wb+") as f:
+ pickle.dump(attributes, f)
+
+ @classmethod
+ def load(cls: type[E], path: PathType) -> E:
+ """
+ Loads the estimator and (if available) the fitted model. This method should only be
+ expected to work to load an estimator that has previously been saved via :meth:`save`.
+
+ Args:
+ path: The directory from which to load the estimator.
+
+ Returns
+ -------
+ The loaded estimator, either fitted or not.
+ """
+ path = Path(path)
+ assert path.is_dir(), "Estimators can only be loaded from a directory."
+
+ estimator = cls.load_parameters(path)
+ try:
+ estimator.load_attributes(path)
+ except FileNotFoundError:
+ warnings.warn(f"Failed to read fitted attributes of `{cls.__name__}` at path '{path}'", stacklevel=2)
+
+ return estimator
+
+ @classmethod
+ def load_parameters(cls: type[E], path: Path) -> E:
+ """
+ Initializes this estimator by loading its parameters. If subclasses overwrite
+ :meth:`save_parameters`, this method should also be overwritten.
+
+ Typically, this method should not be called directly. It is called as part of :meth:`load`.
+
+ Args:
+ path: The directory from which the parameters should be loaded.
+ """
+ json_path = path / "params.json"
+ pickle_path = path / "params.pickle"
+
+ if json_path.exists():
+ with json_path.open() as f:
+ params = json.load(f)
+ else:
+ with pickle_path.open("rb") as f:
+ params = pickle.load(f)
+
+ return cls(**params)
+
+ def load_attributes(self, path: Path) -> None:
+ """
+ Loads the fitted attributes that are stored at the fitted path. If subclasses overwrite
+ :meth:`save_attributes`, this method should also be overwritten.
+
+ Typically, this method should not be called directly. It is called as part of :meth:`load`.
+
+ Args:
+ path: The directory from which the parameters should be loaded.
+
+ Raises
+ ------
+        FileNotFoundError: If no fitted attributes have been stored.
+ """
+ json_path = path / "attributes.json"
+ pickle_path = path / "attributes.pickle"
+
+ if json_path.exists():
+ with json_path.open() as f:
+ self.set_params(json.load(f))
+ else:
+ with pickle_path.open("rb") as f:
+ self.set_params(pickle.load(f))
+
+ # ---------------------------------------------------------------------------------------------
+ # SKLEARN INTERFACE
+
+ def get_params(self, deep: bool = True) -> dict[str, Any]: # pylint: disable=unused-argument
+ """
+ Returns the estimator's parameters as passed to the initializer.
+
+ Args:
+ deep: Ignored. For Scikit-learn compatibility.
+
+ Returns
+ -------
+ The mapping from init parameters to values.
+ """
+ signature = inspect.signature(self.__class__.__init__)
+ parameters = [p.name for p in signature.parameters.values() if p.name != "self"]
+ return {p: getattr(self, p) for p in parameters}
+
+ def set_params(self: E, values: dict[str, Any]) -> E:
+ """
+ Sets the provided values on the estimator. The estimator is returned as well, but the
+ estimator on which this function is called is also modified.
+
+ Args:
+ values: The values to set.
+
+ Returns
+ -------
+ The estimator where the values have been set.
+ """
+ for key, value in values.items():
+ setattr(self, key, value)
+ return self
+
+ def clone(self: E) -> E:
+ """
+ Clones the estimator without copying any fitted attributes. All parameters of this
+ estimator are copied via :meth:`copy.deepcopy`.
+
+ Returns
+ -------
+ The cloned estimator with the same parameters.
+ """
+ return self.__class__(
+ **{
+ name: param.clone() if isinstance(param, BaseEstimator) else copy.deepcopy(param)
+ for name, param in self.get_params().items()
+ }
+ )
+
+ # ---------------------------------------------------------------------------------------------
+ # SPECIAL METHODS
+
+ def __getattr__(self, key: str) -> Any:
+ if key in self.__dict__:
+ return self.__dict__[key]
+ if key.endswith("_") and not key.endswith("__") and key in self.__annotations__:
+ raise NotFittedError(f"`{self.__class__.__name__}` has not been fitted yet")
+ raise AttributeError(f"Attribute `{key}` does not exist on type `{self.__class__.__name__}`.")
+
+ # ---------------------------------------------------------------------------------------------
+ # PRIVATE
+
+ def _num_batches_per_epoch(self, loader: DataLoader[Any]) -> int:
+        """Returns the number of batches that are run for the given data loader across all
+        processes when using the trainer provided by the :meth:`trainer` method. If ``n``
+        processes run ``k`` batches each, this method returns ``k * n``.
+ """
+ trainer = self.trainer()
+ num_batches = len(loader) # type: ignore
+ kwargs = trainer.distributed_sampler_kwargs
+ if kwargs is None:
+ return num_batches
+ return num_batches * kwargs.get("num_replicas", 1)
diff --git a/src/torchgmm/base/estimator/configurable.py b/src/torchgmm/base/estimator/configurable.py
new file mode 100644
index 0000000..eda42a6
--- /dev/null
+++ b/src/torchgmm/base/estimator/configurable.py
@@ -0,0 +1,42 @@
+from pathlib import Path
+from typing import Any, Generic, TypeVar
+
+from torchgmm.base.nn._protocols import ConfigurableModule
+from torchgmm.base.utils import get_generic_type
+
+from .base import BaseEstimator
+from .exception import NotFittedError
+
+M = TypeVar("M", bound=ConfigurableModule) # type: ignore
+
+
+class ConfigurableBaseEstimator(BaseEstimator, Generic[M]):
+ """
+    Extension of the base estimator which manages a single model that uses the
+ :class:`torchgmm.base.nn.Configurable` mixin.
+ """
+
+ model_: M
+
+ def save_attributes(self, path: Path) -> None:
+ # First, store simple attributes
+ super().save_attributes(path)
+
+ # Then, store the model
+ self.model_.save(path / "model")
+
+ def load_attributes(self, path: Path) -> None:
+ # First, load simple attributes
+ super().load_attributes(path)
+
+ # Then, load the model
+ model_cls = get_generic_type(self.__class__, ConfigurableBaseEstimator)
+ self.model_ = model_cls.load(path / "model") # type: ignore
+
+ def __getattr__(self, key: str) -> Any:
+ try:
+ return super().__getattr__(key)
+ except AttributeError as e:
+ if key.endswith("_") and not key.endswith("__") and not key.startswith("_"):
+ raise NotFittedError(f"`{self.__class__.__name__}` has not been fitted yet") from e
+ raise e
diff --git a/src/torchgmm/base/estimator/exception.py b/src/torchgmm/base/estimator/exception.py
new file mode 100644
index 0000000..36ef3fe
--- /dev/null
+++ b/src/torchgmm/base/estimator/exception.py
@@ -0,0 +1,5 @@
+class NotFittedError(Exception):
+ """
+ Exception which is raised whenever properties of an estimator are accessed before the estimator
+ has been fitted.
+ """
diff --git a/src/torchgmm/base/estimator/mixins.py b/src/torchgmm/base/estimator/mixins.py
new file mode 100644
index 0000000..28c5356
--- /dev/null
+++ b/src/torchgmm/base/estimator/mixins.py
@@ -0,0 +1,50 @@
+from typing import Generic
+
+from ._protocols import D_contra, Predictor, R_co, Transformer
+
+
+class TransformerMixin(Generic[D_contra, R_co]):
+ """
+ Mixin that provides a ``fit_transform`` method that chains fitting the estimator and
+ transforming the data it was fitted on.
+ """
+
+ def fit_transform(self: Transformer[D_contra, R_co], data: D_contra) -> R_co:
+ """
+ Fits the estimator using the provided data and subsequently transforms the data using the
+ fitted estimator. It simply chains calls to :meth:`fit` and :meth:`transform`.
+
+ Args:
+ data: The data to use for fitting and to transform. The data must have the
+ same type as for the :meth:`fit` method.
+
+ Returns
+ -------
+ The transformed data. Consult the :meth:`transform` documentation for more information
+ on the return type.
+ """
+ return self.fit(data).transform(data)
+
+
+class PredictorMixin(Generic[D_contra, R_co]):
+ """
+ Mixin that provides a ``fit_predict`` method that chains fitting the estimator and making
+ predictions for the data it was fitted on.
+ """
+
+ def fit_predict(self: Predictor[D_contra, R_co], data: D_contra) -> R_co:
+ """
+ Fits the estimator using the provided data and subsequently predicts the labels for the
+ data using the fitted estimator. It simply chains calls to :meth:`fit` and
+ :meth:`predict`.
+
+ Args:
+ data: The data to use for fitting and to predict labels for. The data must have the
+ same type as for the :meth:`fit` method.
+
+ Returns
+ -------
+ The predicted labels. Consult the :meth:`predict` documentation for more information
+ on the return type.
+ """
+ return self.fit(data).predict(data)
diff --git a/src/torchgmm/base/nn/__init__.py b/src/torchgmm/base/nn/__init__.py
new file mode 100644
index 0000000..4e3a037
--- /dev/null
+++ b/src/torchgmm/base/nn/__init__.py
@@ -0,0 +1,3 @@
+from .configurable import Configurable
+
+__all__ = ["Configurable"]
diff --git a/src/torchgmm/base/nn/_protocols.py b/src/torchgmm/base/nn/_protocols.py
new file mode 100644
index 0000000..a11fee5
--- /dev/null
+++ b/src/torchgmm/base/nn/_protocols.py
@@ -0,0 +1,28 @@
+# pylint: disable=missing-class-docstring,missing-function-docstring
+from typing import Generic, Iterator, OrderedDict, Protocol, Tuple, Type, TypeVar
+
+import torch
+from torch import nn
+
+from torchgmm.base.utils import PathType
+
+C_co = TypeVar("C_co", covariant=True)
+M = TypeVar("M", bound="ConfigurableModule") # type: ignore
+
+
+class ConfigurableModule(Protocol, Generic[C_co]):
+ @property
+ def config(self) -> C_co: ...
+
+ @classmethod
+ def load(cls: Type[M], path: PathType) -> M: ...
+
+ def save(self, path: PathType, compile_model: bool = False) -> None: ...
+
+ def save_config(self, path: PathType) -> None: ...
+
+ def named_children(self) -> Iterator[Tuple[str, nn.Module]]: ...
+
+ def state_dict(self) -> OrderedDict[str, torch.Tensor]: ...
+
+ def load_state_dict(self, state_dict: OrderedDict[str, torch.Tensor]) -> None: ...
diff --git a/src/torchgmm/base/nn/configurable.py b/src/torchgmm/base/nn/configurable.py
new file mode 100644
index 0000000..d028a5b
--- /dev/null
+++ b/src/torchgmm/base/nn/configurable.py
@@ -0,0 +1,170 @@
+from __future__ import annotations
+
+import dataclasses
+import json
+from pathlib import Path
+from typing import Any, Generic
+
+import torch
+from torch import jit, nn
+
+from torchgmm.base.utils import PathType, get_generic_type
+
+from ._protocols import C_co, ConfigurableModule, M
+
+
+class Configurable(Generic[C_co]):
+ """
+ A mixin for any PyTorch module to extend it with storage capabilities.
+
+    By passing a single configuration object to the initializer, this mixin extends the module
+    with :meth:`save` and :meth:`load` methods. These methods make it possible to (1) save the
+    model along with its configuration (i.e. its architecture) and (2) load the model without
+    first constructing an instance of the class.
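+
+    Example:
+        A minimal sketch (the module and its configuration are hypothetical)::
+
+            @dataclasses.dataclass
+            class LinearConfig:
+                num_features: int
+
+            class LinearModel(Configurable[LinearConfig], nn.Module):
+                def __init__(self, config: LinearConfig):
+                    super().__init__(config)
+                    self.weight = nn.Parameter(torch.zeros(config.num_features))
+
+            model = LinearModel(LinearConfig(num_features=8))
+            model.save("my_model")                    # writes config.json and parameters.pt
+            restored = LinearModel.load("my_model")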
+ """
+
+ def __init__(self, config: C_co, *args: Any, **kwargs: Any):
+ """
+ Args:
+ config: The configuration of the architecture.
+ args: Positional arguments that ought to be passed to the superclass.
+ kwargs: Keyword arguments that ought to be passed to the superclass.
+ """
+ assert dataclasses.is_dataclass(config), "Configuration is not a dataclass."
+ assert isinstance(self, nn.Module), "Configurable mixin can only be applied to subclasses of `torch.nn.Module`."
+
+ super().__init__(*args, **kwargs)
+ self.config = config
+
+ @jit.unused
+ def save_config(self: ConfigurableModule[C_co], path: Path) -> None:
+ """
+ Saves only the module's configuration to a file named ``config.json`` in the specified
+ directory. This method should not be called directly. It is called as part of :meth:`save`.
+
+ Args:
+ path: The directory to which to save the configuration and parameter files. The
+ directory may or may not exist but no parent directories are created.
+ """
+ path.mkdir(parents=False, exist_ok=True)
+ with (path / "config.json").open("w+") as f:
+ json.dump(dataclasses.asdict(self.config), f, indent=4)
+
+ @jit.unused
+ def save(self: ConfigurableModule[C_co], path: PathType, compile_model: bool = False) -> None:
+ """
+ Saves the module's configuration and parameters to files in the specified directory. It
+ creates two files, namely ``config.json`` and ``parameters.pt`` which contain the
+ configuration and parameters, respectively.
+
+ Args:
+ path: The directory to which to save the configuration and parameter files. The
+ directory may or may not exist but no parent directories are created.
+ compile_model: Whether the model should be compiled via TorchScript. An additional file
+ called ``model.ptc`` will then be stored. Note that you can simply load the
+ compiled model via :meth:`torch.jit.load` at a later point.
+ """
+ path = Path(path)
+ assert not path.exists() or path.is_dir(), "Modules can only be saved to a directory."
+
+ path.mkdir(parents=True, exist_ok=True)
+
+ # Store the model's configuration and all parameters
+ self.save_config(path)
+ with (path / "parameters.pt").open("wb+") as f:
+ torch.save(self.state_dict(), f) # pylint: disable=no-member
+
+ # Optionally store the compiled model
+ if compile_model:
+ compiled_model = jit.script(self)
+ with (path / "model.ptc").open("wb+") as f:
+ jit.save(compiled_model, f)
+
+ @classmethod
+ def load_config(cls: type[M], path: Path) -> M:
+ """
+ Loads the module by reading the configuration. Parameters are initialized randomly as if
+        the module were initialized from scratch. This method should not be called directly. It
+ is called as part of :meth:`load`.
+
+ Args:
+ path: The directory which contains the ``config.json`` to load.
+
+ Returns
+ -------
+ The loaded model.
+
+ Attention:
+ This method must only be called if the module is initializable solely from a
+ configuration.
+ """
+ config_cls = get_generic_type(cls, Configurable)
+ with (path / "config.json").open("r") as f:
+ config_args = json.load(f)
+ config = _init_config(config_cls, config_args)
+
+ return cls(config) # type: ignore
+
+ @classmethod
+ def load(cls: type[M], path: PathType) -> M:
+ """
+        Loads the module's configuration and parameters from files in the specified directory,
+        then initializes the model with the stored configuration and loads the parameters. This
+        method is typically used after calling :meth:`save` on the model.
+
+ Args:
+ path: The directory which contains the ``config.json`` and ``parameters.pt`` files to
+ load.
+
+ Returns
+ -------
+ The loaded model.
+
+ Note:
+ You can load modules even after you changed their configuration class. The only
+ requirement is that any new configuration options have a default value.
+ """
+ path = Path(path)
+ assert path.is_dir(), "Modules can only be loaded from a directory."
+
+ # Load the config
+ config_cls = get_generic_type(cls, Configurable)
+ with (path / "config.json").open("r") as f:
+ config_args = json.load(f)
+ config = _init_config(config_cls, config_args)
+
+ # Initialize model
+ model = cls(config) # type: ignore
+ with (path / "parameters.pt").open("rb") as f:
+ state_dict = torch.load(f, weights_only=True)
+ model.load_state_dict(state_dict) # pylint: disable=no-member
+ return model
+
+ def clone(self: M, copy_parameters: bool = True) -> M:
+ """
+ Clones this module by initializing another module with the same configuration.
+
+ Args:
+ copy_parameters: Whether to copy this module's parameters or initialize the new module
+ with random parameters.
+
+ Returns
+ -------
+ The cloned module.
+ """
+ cloned = self.__class__(self.config) # type: ignore
+ if copy_parameters:
+ cloned.load_state_dict(self.state_dict()) # pylint: disable=no-member
+ return cloned
+
+
+def _init_config(target: type[Any], args: dict[str, Any]) -> Any:
+ result = {}
+ for key, val in args.items():
+ arg_type = target.__dataclass_fields__[key].type # type: ignore
+ if dataclasses.is_dataclass(arg_type):
+ result[key] = _init_config(arg_type, val)
+ else:
+ result[key] = val
+ return target(**result)
diff --git a/src/torchgmm/base/utils/__init__.py b/src/torchgmm/base/utils/__init__.py
new file mode 100644
index 0000000..573b572
--- /dev/null
+++ b/src/torchgmm/base/utils/__init__.py
@@ -0,0 +1,4 @@
+from .generics import get_generic_type
+from .path import PathType
+
+__all__ = ["PathType", "get_generic_type"]
diff --git a/src/torchgmm/base/utils/generics.py b/src/torchgmm/base/utils/generics.py
new file mode 100644
index 0000000..5576e42
--- /dev/null
+++ b/src/torchgmm/base/utils/generics.py
@@ -0,0 +1,23 @@
+from typing import Any, Type, get_args, get_origin
+
+
+def get_generic_type(cls: Type[Any], origin: Type[Any], index: int = 0) -> Type[Any]:
+ """
+ Returns the ``index``-th generic type of the superclass ``origin``.
+
+ Args:
+ cls: The class on which to inspect the superclasses.
+ origin: The superclass to look for.
+ index: The index of the generic type of the superclass.
+
+ Returns
+ -------
+ The generic type.
+ """
+ for base in cls.__orig_bases__: # type: ignore
+ if get_origin(base) == origin:
+ args = get_args(base)
+ if not args:
+ raise ValueError(f"`{cls.__name__}` does not provide a generic parameter " f"for `{origin.__name__}`")
+ return get_args(base)[index]
+ raise ValueError(f"`{cls.__name__}` does not inherit from `{origin.__name__}`")
diff --git a/src/torchgmm/base/utils/path.py b/src/torchgmm/base/utils/path.py
new file mode 100644
index 0000000..ded30aa
--- /dev/null
+++ b/src/torchgmm/base/utils/path.py
@@ -0,0 +1,9 @@
+import sys
+from os import PathLike
+from typing import Union
+
+if sys.version_info < (3, 9, 0):
+ # PathLike is not generic for Python 3.9
+ PathType = Union[str, PathLike]
+else:
+ PathType = Union[str, PathLike[str]] # type: ignore
diff --git a/src/torchgmm/bayes/__init__.py b/src/torchgmm/bayes/__init__.py
new file mode 100644
index 0000000..61dac91
--- /dev/null
+++ b/src/torchgmm/bayes/__init__.py
@@ -0,0 +1,5 @@
+from .gmm import GaussianMixture
+
+__all__ = [
+ "GaussianMixture",
+]
diff --git a/src/torchgmm/bayes/core/__init__.py b/src/torchgmm/bayes/core/__init__.py
new file mode 100644
index 0000000..56a388a
--- /dev/null
+++ b/src/torchgmm/bayes/core/__init__.py
@@ -0,0 +1,13 @@
+from .normal import cholesky_precision, covariance, log_normal, sample_normal
+from .types import CovarianceType
+from .utils import covariance_dim, covariance_shape
+
+__all__ = [
+ "cholesky_precision",
+ "log_normal",
+ "sample_normal",
+ "covariance",
+ "CovarianceType",
+ "covariance_dim",
+ "covariance_shape",
+]
diff --git a/src/torchgmm/bayes/core/_jit.py b/src/torchgmm/bayes/core/_jit.py
new file mode 100644
index 0000000..376694b
--- /dev/null
+++ b/src/torchgmm/bayes/core/_jit.py
@@ -0,0 +1,89 @@
+# pylint: disable=missing-function-docstring
+import math
+
+import torch
+
+
+def jit_log_normal(
+ x: torch.Tensor,
+ means: torch.Tensor,
+ precisions_cholesky: torch.Tensor,
+ covariance_type: str,
+) -> torch.Tensor:
+ if covariance_type == "full":
+ # Precision shape is `[num_components, dim, dim]`.
+ log_prob = x.new_empty((x.size(0), means.size(0)))
+ # We loop here to not blow up the size of intermediate matrices
+ for k, (mu, prec_chol) in enumerate(zip(means, precisions_cholesky)):
+ inner = x.matmul(prec_chol) - mu.matmul(prec_chol)
+ log_prob[:, k] = inner.square().sum(1)
+ elif covariance_type == "tied":
+ # Precision shape is `[dim, dim]`.
+ a = x.matmul(precisions_cholesky) # [N, D]
+ b = means.matmul(precisions_cholesky) # [K, D]
+ log_prob = (a.unsqueeze(1) - b).square().sum(-1)
+ else:
+ precisions = precisions_cholesky.square()
+ if covariance_type == "diag":
+ # Precision shape is `[num_components, dim]`.
+ x_prob = torch.matmul(x * x, precisions.t())
+ m_prob = torch.einsum("ij,ij,ij->i", means, means, precisions)
+ xm_prob = torch.matmul(x, (means * precisions).t())
+ else: # covariance_type == "spherical"
+ # Precision shape is `[num_components]`
+ x_prob = torch.ger(torch.einsum("ij,ij->i", x, x), precisions)
+ m_prob = torch.einsum("ij,ij->i", means, means) * precisions
+ xm_prob = torch.matmul(x, means.t() * precisions)
+
+ log_prob = x_prob - 2 * xm_prob + m_prob
+
+ num_features = x.size(1)
+ logdet = _cholesky_logdet(num_features, precisions_cholesky, covariance_type)
+ constant = math.log(2 * math.pi) * num_features
+ return logdet - 0.5 * (constant + log_prob)
+
+
+def _cholesky_logdet(
+ num_features: int,
+ precisions_cholesky: torch.Tensor,
+ covariance_type: str,
+) -> torch.Tensor:
+ if covariance_type == "full":
+ return precisions_cholesky.diagonal(dim1=-2, dim2=-1).log().sum(-1)
+ if covariance_type == "tied":
+ return precisions_cholesky.diagonal().log().sum(-1)
+ if covariance_type == "diag":
+ return precisions_cholesky.log().sum(1)
+ # covariance_type == "spherical"
+ return precisions_cholesky.log() * num_features
+
+
+# -------------------------------------------------------------------------------------------------
+
+
+def jit_sample_normal(
+ num: int,
+ mean: torch.Tensor,
+ cholesky_precisions: torch.Tensor,
+ covariance_type: str,
+) -> torch.Tensor:
+ samples = torch.randn(num, mean.size(0), dtype=mean.dtype, device=mean.device)
+ chol_covariance = _cholesky_covariance(cholesky_precisions, covariance_type)
+
+ if covariance_type in ("tied", "full"):
+ scale = chol_covariance.matmul(samples.unsqueeze(-1)).squeeze(-1)
+ else:
+ scale = chol_covariance * samples
+
+ return mean + scale
+
+
+def _cholesky_covariance(chol_precision: torch.Tensor, covariance_type: str) -> torch.Tensor:
+    # For "full"/"tied" covariances, invert the upper-triangular Cholesky factor of the precision
+    # to obtain the (lower-triangular) Cholesky factor of the covariance
+ if covariance_type in ("tied", "full"):
+ num_features = chol_precision.size(-1)
+ target = torch.eye(num_features, dtype=chol_precision.dtype, device=chol_precision.device)
+ return torch.linalg.solve_triangular(chol_precision, target, upper=True).t()
+
+ # Simple covariance type
+ return chol_precision.reciprocal()
diff --git a/src/torchgmm/bayes/core/normal.py b/src/torchgmm/bayes/core/normal.py
new file mode 100644
index 0000000..9423c83
--- /dev/null
+++ b/src/torchgmm/bayes/core/normal.py
@@ -0,0 +1,117 @@
+import torch
+
+from ._jit import jit_log_normal, jit_sample_normal
+from .types import CovarianceType
+
+
+def cholesky_precision(covariances: torch.Tensor, covariance_type: CovarianceType) -> torch.Tensor:
+ """
+ Computes the Cholesky decompositions of the precision matrices induced by the provided
+ covariance matrices.
+
+ Args:
+ covariances: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``,
+ ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the
+ ``covariance_type``. These are the covariance matrices of multivariate Normal
+ distributions.
+ covariance_type: The type of covariance for the covariance matrices given.
+
+ Returns
+ -------
+    A tensor of the same shape as ``covariances``, providing the upper-triangular Cholesky
+    factors of the precision matrices.
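+
+    Note:
+        For ``"full"`` and ``"tied"`` covariances, this computes ``L = cholesky(covariance)`` and
+        returns ``inverse(L).T``, i.e. an upper-triangular factor ``U`` satisfying
+        ``U @ U.T == inverse(covariance)``.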
+ """
+ if covariance_type in ("tied", "full"):
+ # Compute Cholesky decomposition
+ cholesky = torch.linalg.cholesky(covariances)
+ # Invert
+ num_features = covariances.size(-1)
+ target = torch.eye(num_features, dtype=covariances.dtype, device=covariances.device)
+ if covariance_type == "full":
+ num_components = covariances.size(0)
+ target = target.unsqueeze(0).expand(num_components, -1, -1)
+ return torch.linalg.solve_triangular(cholesky, target, upper=False).transpose(-2, -1)
+
+ # "Simple" kind of covariance
+ return covariances.sqrt().reciprocal()
+
+
+def covariance(cholesky_precisions: torch.Tensor, covariance_type: CovarianceType) -> torch.Tensor:
+ """
+    Computes the covariance matrices from the provided Cholesky decompositions of the precision
+ matrices. This function is the inverse of :meth:`cholesky_precision`.
+
+ Args:
+ cholesky_precisions: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``,
+ ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the
+ ``covariance_type``. These are the Cholesky decompositions of the precisions of
+ multivariate Normal distributions.
+ covariance_type: The type of covariance for the covariance matrices given.
+
+ Returns
+ -------
+ A tensor of the same shape as ``cholesky_precisions``, providing the covariance matrices
+ corresponding to the given Cholesky-decomposed precision matrices.
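+
+    Example:
+        A small sketch of the round trip with :meth:`cholesky_precision` (the shapes and values
+        below are illustrative)::
+
+            covs = torch.diag_embed(torch.rand(4, 3) + 0.5)  # [num_components, dim, dim]
+            chol = cholesky_precision(covs, "full")
+            restored = covariance(chol, "full")              # approximately equal to covs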
+ """
+ if covariance_type in ("tied", "full"):
+        cholesky_covars = torch.linalg.inv(cholesky_precisions)
+        if covariance_type == "tied":
+            return torch.matmul(cholesky_covars.T, cholesky_covars)
+        return torch.bmm(cholesky_covars.transpose(1, 2), cholesky_covars)
+
+ # "Simple" kind of covariance
+ return (cholesky_precisions**2).reciprocal()
+
+
+def log_normal(
+ x: torch.Tensor,
+ means: torch.Tensor,
+ precisions_cholesky: torch.Tensor,
+ covariance_type: CovarianceType,
+) -> torch.Tensor:
+ """
+ Computes the log-probability of the given data for multiple multivariate Normal distributions
+ defined by their means and covariances.
+
+ Args:
+ x: A tensor of shape ``[num_datapoints, dim]``. This is the data to compute the
+ log-probability for.
+ means: A tensor of shape ``[num_components, dim]``. These are the means of the multivariate
+ Normal distributions.
+ precisions_cholesky: A tensor of shape ``[num_components, dim, dim]``, ``[dim, dim]``,
+ ``[num_components, dim]``, ``[dim]`` or ``[num_components]`` depending on the
+ ``covariance_type``. These are the upper-triangular Cholesky matrices for the inverse
+ covariance matrices (i.e. precision matrices) of the multivariate Normal distributions.
+ covariance_type: The type of covariance for the covariance matrices given.
+
+ Returns
+ -------
+ A tensor of shape ``[num_datapoints, num_components]`` with the log-probabilities for each
+ datapoint and each multivariate Normal distribution.
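+
+    Example:
+        A small sketch with unit ("diag") variances; the shapes are the key point::
+
+            x = torch.randn(100, 3)                            # [num_datapoints, dim]
+            means = torch.randn(5, 3)                          # [num_components, dim]
+            chol = cholesky_precision(torch.ones(5, 3), "diag")
+            log_probs = log_normal(x, means, chol, "diag")     # shape [100, 5]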
+ """
+ return jit_log_normal(x, means, precisions_cholesky, covariance_type)
+
+
+def sample_normal(
+ num: int,
+ mean: torch.Tensor,
+ cholesky_precisions: torch.Tensor,
+ covariance_type: CovarianceType,
+) -> torch.Tensor:
+ """
+ Samples the given number of times from the multivariate Normal distribution described by the
+ mean and Cholesky precision.
+
+ Args:
+ num: The number of times to sample.
+        mean: A tensor of shape ``[dim]`` with the mean of the distribution to sample from.
+        cholesky_precisions: A tensor of shape ``[dim, dim]``, ``[dim]`` or ``[1]`` depending on
+            the ``covariance_type``. This is the corresponding Cholesky precision matrix for the
+            mean.
+ covariance_type: The type of covariance for the covariance matrix given.
+
+ Returns
+ -------
+ A tensor of shape ``[num_samples, dim]`` with the samples from the Normal distribution.
+ """
+ return jit_sample_normal(num, mean, cholesky_precisions, covariance_type)
diff --git a/src/torchgmm/bayes/core/types.py b/src/torchgmm/bayes/core/types.py
new file mode 100644
index 0000000..26911cf
--- /dev/null
+++ b/src/torchgmm/bayes/core/types.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from typing import Literal
+
+CovarianceType = Literal["full", "tied", "diag", "spherical"]
+CovarianceType.__doc__ = """
+The type of covariance to use for a set of multivariate Normal distributions.
+
+- **full**: Each distribution has a full covariance matrix. Covariance matrix is a tensor of shape
+ ``[num_components, num_features, num_features]``.
+- **tied**: All distributions share the same full covariance matrix. Covariance matrix is a tensor
+ of shape ``[num_features, num_features]``.
+- **diag**: Each distribution has a diagonal covariance matrix. Covariance matrix is a tensor of
+ shape ``[num_components, num_features]``.
+- **spherical**: Each distribution has a diagonal covariance matrix which is a multiple of the
+ identity matrix. Covariance matrix is a tensor of shape ``[num_components]``.
+"""
diff --git a/src/torchgmm/bayes/core/utils.py b/src/torchgmm/bayes/core/utils.py
new file mode 100644
index 0000000..1341d1e
--- /dev/null
+++ b/src/torchgmm/bayes/core/utils.py
@@ -0,0 +1,45 @@
+import torch
+
+from .types import CovarianceType
+
+
+def covariance_dim(covariance_type: CovarianceType) -> int:
+ """
+    Returns the number of dimensions of the covariance matrix for a set of components.
+
+ Args:
+ covariance_type: The type of covariance to obtain the dimension for.
+
+ Returns
+ -------
+ The number of dimensions.
+ """
+ if covariance_type == "full":
+ return 3
+ if covariance_type in ("tied", "diag"):
+ return 2
+ return 1
+
+
+def covariance_shape(num_components: int, num_features: int, covariance_type: CovarianceType) -> torch.Size:
+ """
+ Returns the expected shape of the covariance matrix for the given number of components with the
+ provided number of features based on the covariance type.
+
+ Args:
+ num_components: The number of Normal distributions to describe with the covariance.
+ num_features: The dimensionality of the Normal distributions.
+ covariance_type: The type of covariance to use.
+
+ Returns
+ -------
+ The expected size of the tensor representing the covariances.
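+
+    Example:
+        For 3 components with 2 features::
+
+            covariance_shape(3, 2, "full")       # torch.Size([3, 2, 2])
+            covariance_shape(3, 2, "tied")       # torch.Size([2, 2])
+            covariance_shape(3, 2, "diag")       # torch.Size([3, 2])
+            covariance_shape(3, 2, "spherical")  # torch.Size([3])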
+ """
+ if covariance_type == "full":
+ return torch.Size([num_components, num_features, num_features])
+ if covariance_type == "tied":
+ return torch.Size([num_features, num_features])
+ if covariance_type == "diag":
+ return torch.Size([num_components, num_features])
+ # covariance_type == "spherical"
+ return torch.Size([num_components])
diff --git a/src/torchgmm/bayes/gmm/__init__.py b/src/torchgmm/bayes/gmm/__init__.py
new file mode 100644
index 0000000..112ee67
--- /dev/null
+++ b/src/torchgmm/bayes/gmm/__init__.py
@@ -0,0 +1,8 @@
+from .estimator import GaussianMixture
+from .model import GaussianMixtureModel, GaussianMixtureModelConfig
+
+__all__ = [
+ "GaussianMixture",
+ "GaussianMixtureModel",
+ "GaussianMixtureModelConfig",
+]
diff --git a/src/torchgmm/bayes/gmm/estimator.py b/src/torchgmm/bayes/gmm/estimator.py
new file mode 100644
index 0000000..96f8746
--- /dev/null
+++ b/src/torchgmm/bayes/gmm/estimator.py
@@ -0,0 +1,317 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, List, Tuple, cast
+
+import numpy as np
+import torch
+
+from torchgmm.base import ConfigurableBaseEstimator
+from torchgmm.base.data import DataLoader, TensorLike, collate_tensor, dataset_from_tensors
+from torchgmm.base.estimator import PredictorMixin
+from torchgmm.bayes.core import CovarianceType
+from torchgmm.clustering import KMeans
+
+from .lightning_module import (
+ GaussianMixtureKmeansInitLightningModule,
+ GaussianMixtureLightningModule,
+ GaussianMixtureRandomInitLightningModule,
+)
+from .model import GaussianMixtureModel, GaussianMixtureModelConfig
+from .types import GaussianMixtureInitStrategy
+
+logger = logging.getLogger(__name__)
+
+
+class GaussianMixture(
+ ConfigurableBaseEstimator[GaussianMixtureModel], # type: ignore
+ PredictorMixin[TensorLike, torch.Tensor],
+):
+ """
+ Probabilistic model assuming that data is generated from a mixture of Gaussians.
+
+ The mixture is assumed to be composed of a fixed number of components with individual means
+ and covariances. More information on Gaussian mixture models (GMMs) is available on
+    `Wikipedia <https://en.wikipedia.org/wiki/Mixture_model>`_.
+
+ See Also
+ --------
+ .. currentmodule:: torchgmm.bayes.gmm
+ .. autosummary::
+ :nosignatures:
+ :template: classes/pytorch_module.rst
+
+ GaussianMixtureModel
+ GaussianMixtureModelConfig
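+
+    Example:
+        A minimal usage sketch (``x`` stands in for any float32 tensor of tabular data; the
+        variable names are illustrative)::
+
+            import torch
+            from torchgmm.bayes import GaussianMixture
+
+            x = torch.randn(1000, 2)
+            gmm = GaussianMixture(num_components=3, covariance_type="full").fit(x)
+            labels = gmm.predict(x)    # most likely component per datapoint
+            samples = gmm.sample(100)  # draw new datapoints from the fitted mixture
+            mean_nll = gmm.score(x)    # average negative log-likelihood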
+ """
+
+ #: The fitted PyTorch module with all estimated parameters.
+ model_: GaussianMixtureModel
+ #: A boolean indicating whether the model converged during training.
+ converged_: bool
+ #: The number of iterations the model was fitted for, excluding initialization.
+ num_iter_: int
+ #: The average per-datapoint negative log-likelihood at the last training step.
+ nll_: float
+
+ def __init__(
+ self,
+ num_components: int = 1,
+ *,
+ covariance_type: CovarianceType = "diag",
+ init_strategy: GaussianMixtureInitStrategy = "kmeans",
+ init_means: torch.Tensor | None = None,
+ convergence_tolerance: float = 1e-3,
+ covariance_regularization: float = 1e-6,
+ batch_size: int | None = None,
+ trainer_params: dict[str, Any] | None = None,
+ ):
+ """
+ Args:
+ num_components: The number of components in the GMM. The dimensionality of each
+ component is automatically inferred from the data.
+ covariance_type: The type of covariance to assume for all Gaussian components.
+ init_strategy: The strategy for initializing component means and covariances.
+ init_means: An optional initial guess for the means of the components. If provided,
+ must be a tensor of shape ``[num_components, num_features]``. If this is given,
+ the ``init_strategy`` is ignored and the means are handled as if K-means
+ initialization has been run.
+ convergence_tolerance: The change in the per-datapoint negative log-likelihood which
+ implies that training has converged.
+ covariance_regularization: A small value which is added to the diagonal of the
+ covariance matrix to ensure that it is positive semi-definite.
+ batch_size: The batch size to use when fitting the model. If not provided, the full
+ data will be used as a single batch. Set this if the full data does not fit into
+ memory.
+ trainer_params: Initialization parameters to use when initializing a PyTorch Lightning
+ trainer. By default, it disables various stdout logs unless TorchGMM is configured to
+ do verbose logging. Checkpointing and logging are disabled regardless of the log
+ level. This estimator further sets the following overridable defaults:
+
+ - ``max_epochs=100``
+
+ Note:
+                The number of epochs passed to the initializer only defines the number of
+                optimization epochs. Prior to that, initialization is run, which may perform
+                additional iterations through the data.
+
+ Note:
+                For batch training, the number of epochs run (i.e. the number of passes through
+                the data) does not align with the number of epochs passed to the initializer.
+                This is because the EM algorithm needs to be split across two epochs, so the
+                actual number of minimum/maximum epochs is doubled. Nonetheless, :attr:`num_iter_`
+                indicates how many EM iterations have been run.
+ """
+ super().__init__(
+ default_params={"max_epochs": 100},
+ user_params=trainer_params,
+ )
+
+ self.num_components = num_components
+ self.covariance_type = covariance_type
+ self.init_strategy = init_strategy
+ self.init_means = init_means
+ self.convergence_tolerance = convergence_tolerance
+ self.covariance_regularization = covariance_regularization
+
+ self.batch_size = batch_size
+
+ def fit(self, data: TensorLike) -> GaussianMixture:
+ """
+ Fits the Gaussian mixture on the provided data, estimating component priors, means and
+ covariances. Parameters are estimated using the EM algorithm.
+
+ Args:
+ data: The tabular data to fit on. The dimensionality of the Gaussian mixture is
+ automatically inferred from this data.
+
+ Returns
+ -------
+ The fitted Gaussian mixture.
+ """
+ if (data.dtype in [torch.float64, np.float64]) and (self.trainer_params.get("precision", 32) == 32):
+ raise ValueError(
+ "Data is of type float64. Transform it to float32 or use trainer_params={'precision': 64}."
+ )
+ # Initialize the model
+ num_features = len(data[0])
+ config = GaussianMixtureModelConfig(
+ num_components=self.num_components,
+ num_features=num_features,
+ covariance_type=self.covariance_type, # type: ignore
+ )
+ self.model_ = GaussianMixtureModel(config)
+
+ # Setup the data loading
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+        is_batch_training = self._num_batches_per_epoch(loader) > 1
+
+ # Run k-means if required or copy means
+ if self.init_means is not None:
+ self.model_.means.copy_(self.init_means)
+ elif self.init_strategy in ("kmeans", "kmeans++"):
+ logger.info("Fitting K-means estimator...")
+ params = self.trainer_params_user
+ if self.init_strategy == "kmeans++":
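+                # With max_epochs=0, the K-means estimator only runs its initialization, so the
+                # centroids come from the K-means++ seeding alone, without Lloyd iterations
+                # (see GaussianMixtureInitStrategy).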
+ params = {**(params or {}), **{"max_epochs": 0}}
+
+ estimator = KMeans(
+ self.num_components,
+ batch_size=self.batch_size,
+ trainer_params=params,
+ ).fit(data)
+ self.model_.means.copy_(estimator.model_.centroids)
+
+ # Run initialization
+ logger.info("Running initialization...")
+ if self.init_strategy in ("kmeans", "kmeans++") and self.init_means is None:
+ module = GaussianMixtureKmeansInitLightningModule(
+ self.model_,
+ covariance_regularization=self.covariance_regularization,
+ )
+ self.trainer(max_epochs=1).fit(module, loader)
+ else:
+ module = GaussianMixtureRandomInitLightningModule(
+ self.model_,
+ covariance_regularization=self.covariance_regularization,
+ is_batch_training=is_batch_training,
+ use_model_means=self.init_means is not None,
+ )
+ self.trainer(max_epochs=1 + int(is_batch_training)).fit(module, loader)
+
+ # Fit model
+ logger.info("Fitting Gaussian mixture...")
+ module = GaussianMixtureLightningModule(
+ self.model_,
+ convergence_tolerance=self.convergence_tolerance,
+ covariance_regularization=self.covariance_regularization,
+ is_batch_training=is_batch_training,
+ )
+ trainer = self.trainer(max_epochs=cast(int, self.trainer_params["max_epochs"]) * (1 + int(is_batch_training)))
+ trainer.fit(module, loader)
+
+ # Assign convergence properties
+ self.num_iter_ = module.current_epoch
+ if is_batch_training:
+ self.num_iter_ //= 2
+ self.converged_ = trainer.should_stop
+ self.nll_ = cast(float, trainer.callback_metrics["nll"].item())
+ return self
+
+ def sample(self, num_datapoints: int) -> torch.Tensor:
+ """
+ Samples datapoints from the fitted Gaussian mixture.
+
+ Args:
+ num_datapoints: The number of datapoints to sample.
+
+ Returns
+ -------
+ A tensor of shape ``[num_datapoints, dim]`` providing the samples.
+
+ Note:
+ This method does not parallelize across multiple processes, i.e. performs no
+ synchronization.
+ """
+ return self.model_.sample(num_datapoints)
+
+ def score(self, data: TensorLike) -> float:
+ """
+ Computes the average negative log-likelihood (NLL) of the provided datapoints.
+
+ Args:
+ data: The datapoints for which to evaluate the NLL.
+
+ Returns
+ -------
+ The average NLL of all datapoints.
+
+ Note:
+ See :meth:`score_samples` to obtain NLL values for individual datapoints.
+ """
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+ result = self.trainer().test(GaussianMixtureLightningModule(self.model_), loader, verbose=False)
+ return result[0]["nll"]
+
+ def score_samples(self, data: TensorLike) -> torch.Tensor:
+ """
+ Computes the negative log-likelihood (NLL) of each of the provided datapoints.
+
+ Args:
+ data: The datapoints for which to compute the NLL.
+
+ Returns
+ -------
+ A tensor of shape ``[num_datapoints]`` with the NLL for each datapoint.
+
+ Attention:
+ When calling this function in a multi-process environment, each process receives only
+ a subset of the predictions. If you want to aggregate predictions, make sure to gather
+ the values returned from this method.
+ """
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+ result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader)
+        return torch.cat([x[1] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
+
+ def predict(self, data: TensorLike) -> torch.Tensor:
+ """
+ Computes the most likely components for each of the provided datapoints.
+
+ Args:
+ data: The datapoints for which to obtain the most likely components.
+
+ Returns
+ -------
+ A tensor of shape ``[num_datapoints]`` with the indices of the most likely components.
+
+ Note:
+ Use :meth:`predict_proba` to obtain probabilities for each component instead of the
+ most likely component only.
+
+ Attention:
+ When calling this function in a multi-process environment, each process receives only
+ a subset of the predictions. If you want to aggregate predictions, make sure to gather
+ the values returned from this method.
+ """
+ return self.predict_proba(data).argmax(-1)
+
+ def predict_proba(self, data: TensorLike) -> torch.Tensor:
+ """
+ Computes a distribution over the components for each of the provided datapoints.
+
+ Args:
+ data: The datapoints for which to compute the component assignment probabilities.
+
+ Returns
+ -------
+ A tensor of shape ``[num_datapoints, num_components]`` with the assignment
+ probabilities for each component and datapoint. Note that each row of the vector sums
+ to 1, i.e. the returned tensor provides a proper distribution over the components for
+ each datapoint.
+
+ Attention:
+ When calling this function in a multi-process environment, each process receives only
+ a subset of the predictions. If you want to aggregate predictions, make sure to gather
+ the values returned from this method.
+ """
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+ result = self.trainer().predict(GaussianMixtureLightningModule(self.model_), loader)
+ return torch.cat([x[0] for x in cast(List[Tuple[torch.Tensor, torch.Tensor]], result)])
diff --git a/src/torchgmm/bayes/gmm/lightning_module.py b/src/torchgmm/bayes/gmm/lightning_module.py
new file mode 100644
index 0000000..7175381
--- /dev/null
+++ b/src/torchgmm/bayes/gmm/lightning_module.py
@@ -0,0 +1,329 @@
+from __future__ import annotations
+
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.callbacks import EarlyStopping
+from torchmetrics import MeanMetric
+
+from torchgmm.bayes.core import cholesky_precision
+from torchgmm.utils import NonparametricLightningModule
+
+from .metrics import CovarianceAggregator, MeanAggregator, PriorAggregator
+from .model import GaussianMixtureModel
+
+# -------------------------------------------------------------------------------------------------
+# TRAINING
+
+
+class GaussianMixtureLightningModule(NonparametricLightningModule):
+ """
+ Lightning module for training and evaluating a Gaussian mixture model.
+ """
+
+ def __init__(
+ self,
+ model: GaussianMixtureModel,
+ convergence_tolerance: float = 1e-3,
+ covariance_regularization: float = 1e-6,
+ is_batch_training: bool = False,
+ ):
+ """
+ Args:
+ model: The Gaussian mixture model to use for training/evaluation.
+ convergence_tolerance: The change in the per-datapoint negative log-likelihood which
+ implies that training has converged.
+ covariance_regularization: A small value which is added to the diagonal of the
+ covariance matrix to ensure that it is positive semi-definite.
+ is_batch_training: Whether training is performed on mini-batches instead of the entire
+ data at once. In the case of batching, the EM-algorithm is "split" across two
+ epochs.
+ """
+ super().__init__()
+
+ self.model = model
+ self.convergence_tolerance = convergence_tolerance
+ self.is_batch_training = is_batch_training
+
+ # For batch training, we store a model copy such that we can "replay" responsibilities
+ if self.is_batch_training:
+ self.model_copy = GaussianMixtureModel(self.model.config)
+ self.model_copy.load_state_dict(self.model.state_dict())
+
+ # Initialize aggregators
+ self.prior_aggregator = PriorAggregator(
+ num_components=self.model.config.num_components,
+ dist_sync_fn=self.all_gather,
+ )
+ self.mean_aggregator = MeanAggregator(
+ num_components=self.model.config.num_components,
+ num_features=self.model.config.num_features,
+ dist_sync_fn=self.all_gather,
+ )
+ self.covar_aggregator = CovarianceAggregator(
+ num_components=self.model.config.num_components,
+ num_features=self.model.config.num_features,
+ covariance_type=self.model.config.covariance_type,
+ reg=covariance_regularization,
+ dist_sync_fn=self.all_gather,
+ )
+
+ # Initialize metrics
+ self.metric_nll = MeanMetric(dist_sync_fn=self.all_gather)
+
+ def configure_callbacks(self) -> list[pl.Callback]:
+ if self.convergence_tolerance == 0:
+ return []
+ early_stopping = EarlyStopping(
+ "nll",
+ min_delta=self.convergence_tolerance,
+ patience=2 if self.is_batch_training else 1,
+ check_on_train_epoch_end=True,
+            strict=False,  # allows the NLL not to be logged in every epoch
+ )
+ return [early_stopping]
+
+ def on_train_epoch_start(self) -> None:
+ self.prior_aggregator.reset()
+ self.mean_aggregator.reset()
+ self.covar_aggregator.reset()
+
+ def nonparametric_training_step(self, batch: torch.Tensor, _batch_idx: int) -> None:
+ ### E-Step
+ if self._computes_responsibilities_on_live_model:
+ log_responsibilities, log_probs = self.model.forward(batch)
+ else:
+ log_responsibilities, log_probs = self.model_copy.forward(batch)
+ responsibilities = log_responsibilities.exp()
+
+ # Compute the NLL for early stopping
+ if self._should_log_nll:
+ self.metric_nll.update(-log_probs)
+ self.log("nll", self.metric_nll, on_step=False, on_epoch=True, prog_bar=True)
+
+ ### (Partial) M-Step
+ if self._should_update_means:
+ self.prior_aggregator.update(responsibilities)
+ self.mean_aggregator.update(batch, responsibilities)
+ if self._should_update_covars:
+ means = self.mean_aggregator.compute()
+ self.covar_aggregator.update(batch, responsibilities, means)
+ else:
+ self.covar_aggregator.update(batch, responsibilities, self.model.means)
+
+ def nonparametric_training_epoch_end(self) -> None:
+ # Prior to updating the model, we might need to copy it in the case of batch training
+ if self._requires_to_copy_live_model:
+ self.model_copy.load_state_dict(self.model.state_dict())
+
+ # Finalize the M-Step
+ if self._should_update_means:
+ priors = self.prior_aggregator.compute()
+ self.model.component_probs.copy_(priors)
+
+ means = self.mean_aggregator.compute()
+ self.model.means.copy_(means)
+
+ if self._should_update_covars:
+ covars = self.covar_aggregator.compute()
+ self.model.precisions_cholesky.copy_(cholesky_precision(covars, self.model.config.covariance_type))
+
+ def test_step(self, batch: torch.Tensor, _batch_idx: int) -> None:
+ _, log_probs = self.model.forward(batch)
+ self.metric_nll.update(-log_probs)
+ self.log("nll", self.metric_nll)
+
+ def predict_step(self, batch: torch.Tensor, batch_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+ log_responsibilities, log_probs = self.model.forward(batch)
+ return log_responsibilities.exp(), -log_probs
+
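+    # Batch-training schedule implemented by the properties below: a single EM iteration spans
+    # two epochs. Even epochs compute responsibilities on the live model, accumulate priors and
+    # means, and snapshot the model at epoch end; odd epochs replay the responsibilities on the
+    # snapshot, accumulate covariances against the freshly updated means, and log the NLL used
+    # for early stopping. Without batch training, every epoch performs a full EM iteration.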
+ @property
+ def _computes_responsibilities_on_live_model(self) -> bool:
+ if not self.is_batch_training:
+ return True
+ return self.current_epoch % 2 == 0
+
+ @property
+ def _requires_to_copy_live_model(self) -> bool:
+ if not self.is_batch_training:
+ return False
+ return self.current_epoch % 2 == 0
+
+ @property
+ def _should_log_nll(self) -> bool:
+ if not self.is_batch_training:
+ return True
+ return self.current_epoch % 2 == 1
+
+ @property
+ def _should_update_means(self) -> bool:
+ if not self.is_batch_training:
+ return True
+ return self.current_epoch % 2 == 0
+
+ @property
+ def _should_update_covars(self) -> bool:
+ if not self.is_batch_training:
+ return True
+ return self.current_epoch % 2 == 1
+
+
+# -------------------------------------------------------------------------------------------------
+# INIT STRATEGIES
+
+
+class GaussianMixtureKmeansInitLightningModule(NonparametricLightningModule):
+ """
+ Lightning module for initializing a Gaussian mixture from centroids found via K-Means.
+ """
+
+ def __init__(self, model: GaussianMixtureModel, covariance_regularization: float):
+ """
+ Args:
+ model: The model whose parameters to initialize.
+ covariance_regularization: A small value which is added to the diagonal of the
+ covariance matrix to ensure that it is positive semi-definite.
+ """
+ super().__init__()
+
+ self.model = model
+
+ self.prior_aggregator = PriorAggregator(
+ num_components=self.model.config.num_components,
+ dist_sync_fn=self.all_gather,
+ )
+ self.covar_aggregator = CovarianceAggregator(
+ num_components=self.model.config.num_components,
+ num_features=self.model.config.num_features,
+ covariance_type=self.model.config.covariance_type,
+ reg=covariance_regularization,
+ dist_sync_fn=self.all_gather,
+ )
+
+ def on_train_epoch_start(self) -> None:
+ self.prior_aggregator.reset()
+ self.covar_aggregator.reset()
+
+ def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+ # Just like for k-means, responsibilities are one-hot assignments to the clusters
+ responsibilities = _one_hot_responsibilities(batch, self.model.means)
+
+ # Then, we can update the aggregators
+ self.prior_aggregator.update(responsibilities)
+ self.covar_aggregator.update(batch, responsibilities, self.model.means)
+
+ def nonparametric_training_epoch_end(self) -> None:
+ priors = self.prior_aggregator.compute()
+ self.model.component_probs.copy_(priors)
+
+ covars = self.covar_aggregator.compute()
+ self.model.precisions_cholesky.copy_(cholesky_precision(covars, self.model.config.covariance_type))
+
+
+class GaussianMixtureRandomInitLightningModule(NonparametricLightningModule):
+ """
+    Lightning module for initializing a Gaussian mixture randomly or from one-hot assignments to
+    user-provided means that were not obtained via K-means.
+
+ For batch training, this requires two epochs, otherwise, it requires a single epoch.
+ """
+
+ def __init__(
+ self,
+ model: GaussianMixtureModel,
+ covariance_regularization: float,
+ is_batch_training: bool,
+ use_model_means: bool,
+ ):
+ """
+ Args:
+ model: The model whose parameters to initialize.
+ covariance_regularization: A small value which is added to the diagonal of the
+ covariance matrix to ensure that it is positive semi-definite.
+ is_batch_training: Whether training is performed on mini-batches instead of the entire
+ data at once.
+ use_model_means: Whether the model's means ought to be used for one-hot component
+ assignments.
+ """
+ super().__init__()
+
+ self.model = model
+ self.is_batch_training = is_batch_training
+ self.use_model_means = use_model_means
+
+ self.prior_aggregator = PriorAggregator(
+ num_components=self.model.config.num_components,
+ dist_sync_fn=self.all_gather,
+ )
+ self.mean_aggregator = MeanAggregator(
+ num_components=self.model.config.num_components,
+ num_features=self.model.config.num_features,
+ dist_sync_fn=self.all_gather,
+ )
+ self.covar_aggregator = CovarianceAggregator(
+ num_components=self.model.config.num_components,
+ num_features=self.model.config.num_features,
+ covariance_type=self.model.config.covariance_type,
+ reg=covariance_regularization,
+ dist_sync_fn=self.all_gather,
+ )
+
+ # For batch training, we store a model copy such that we can "replay" responsibilities
+ if self.is_batch_training and self.use_model_means:
+ self.model_copy = GaussianMixtureModel(self.model.config)
+ self.model_copy.load_state_dict(self.model.state_dict())
+
+ def on_train_epoch_start(self) -> None:
+ self.prior_aggregator.reset()
+ self.mean_aggregator.reset()
+ self.covar_aggregator.reset()
+
+ def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+ if self.use_model_means:
+ if self.current_epoch == 0:
+ responsibilities = _one_hot_responsibilities(batch, self.model.means)
+ else:
+ responsibilities = _one_hot_responsibilities(batch, self.model_copy.means)
+ else:
+ responsibilities = torch.rand(
+ batch.size(0),
+ self.model.config.num_components,
+ device=batch.device,
+ dtype=batch.dtype,
+ )
+ responsibilities = responsibilities / responsibilities.sum(1, keepdim=True)
+
+ if self.current_epoch == 0:
+ self.prior_aggregator.update(responsibilities)
+ self.mean_aggregator.update(batch, responsibilities)
+ if not self.is_batch_training:
+ means = self.mean_aggregator.compute()
+ self.covar_aggregator.update(batch, responsibilities, means)
+ else:
+ # Only reached if batch training
+ self.covar_aggregator.update(batch, responsibilities, self.model.means)
+
+ def nonparametric_training_epoch_end(self) -> None:
+ if self.current_epoch == 0 and self.is_batch_training:
+ self.model_copy.load_state_dict(self.model.state_dict())
+
+ if self.current_epoch == 0:
+ priors = self.prior_aggregator.compute()
+ self.model.component_probs.copy_(priors)
+
+ means = self.mean_aggregator.compute()
+ self.model.means.copy_(means)
+
+ if (self.current_epoch == 0 and not self.is_batch_training) or self.current_epoch == 1:
+ covars = self.covar_aggregator.compute()
+ self.model.precisions_cholesky.copy_(cholesky_precision(covars, self.model.config.covariance_type))
+
+
+def _one_hot_responsibilities(data: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
+ distances = torch.cdist(data, centroids)
+ assignments = distances.min(1).indices
+ onehot = torch.eye(
+ centroids.size(0),
+ device=data.device,
+ dtype=data.dtype,
+ )
+ return onehot[assignments]
diff --git a/src/torchgmm/bayes/gmm/metrics.py b/src/torchgmm/bayes/gmm/metrics.py
new file mode 100644
index 0000000..3fc14a5
--- /dev/null
+++ b/src/torchgmm/bayes/gmm/metrics.py
@@ -0,0 +1,146 @@
+from typing import Any, Callable, Optional
+
+import torch
+from torchmetrics import Metric
+
+from torchgmm.bayes.core import CovarianceType, covariance_shape
+
+
+class PriorAggregator(Metric):
+ """
+    The prior aggregator aggregates component probabilities over batches and processes.
+ """
+
+ full_state_update = False
+
+ def __init__(
+ self,
+ num_components: int,
+ *,
+ dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+ ):
+ super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
+
+ self.responsibilities: torch.Tensor
+ self.add_state("responsibilities", torch.zeros(num_components), dist_reduce_fx="sum")
+
+ def update(self, responsibilities: torch.Tensor) -> None:
+ # Responsibilities have shape [N, K]
+ self.responsibilities.add_(responsibilities.sum(0))
+
+ def compute(self) -> torch.Tensor:
+ return self.responsibilities / self.responsibilities.sum()
+
+
+class MeanAggregator(Metric):
+ """
+ The mean aggregator aggregates component means over batches and processes.
+ """
+
+ full_state_update = False
+
+ def __init__(
+ self,
+ num_components: int,
+ num_features: int,
+ *,
+ dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+ ):
+ super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
+
+ self.mean_sum: torch.Tensor
+ self.add_state("mean_sum", torch.zeros(num_components, num_features), dist_reduce_fx="sum")
+
+ self.component_weights: torch.Tensor
+ self.add_state("component_weights", torch.zeros(num_components), dist_reduce_fx="sum")
+
+ def update(self, data: torch.Tensor, responsibilities: torch.Tensor) -> None:
+ # Data has shape [N, D]
+ # Responsibilities have shape [N, K]
+ self.mean_sum.add_(responsibilities.t().matmul(data))
+ self.component_weights.add_(responsibilities.sum(0))
+
+ def compute(self) -> torch.Tensor:
+ return self.mean_sum / self.component_weights.unsqueeze(1)
+
+
+class CovarianceAggregator(Metric):
+ """
+ The covariance aggregator aggregates component covariances over batches and processes.
+ """
+
+ full_state_update = False
+
+ def __init__(
+ self,
+ num_components: int,
+ num_features: int,
+ covariance_type: CovarianceType,
+ reg: float,
+ *,
+ dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+ ):
+ super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
+
+ self.num_components = num_components
+ self.num_features = num_features
+ self.covariance_type = covariance_type
+ self.reg = reg
+
+ self.covariance_sum: torch.Tensor
+ self.add_state(
+ "covariance_sum",
+ torch.zeros(covariance_shape(num_components, num_features, covariance_type)),
+ dist_reduce_fx="sum",
+ )
+
+ self.component_weights: torch.Tensor
+ self.add_state("component_weights", torch.zeros(num_components), dist_reduce_fx="sum")
+
+ def update(self, data: torch.Tensor, responsibilities: torch.Tensor, means: torch.Tensor) -> None:
+ data_component_weights = responsibilities.sum(0)
+ self.component_weights.add_(data_component_weights)
+
+ if self.covariance_type in ("spherical", "diag"):
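+            # Accumulate sum_n r_nk * (x_n - mean_k)^2 per feature, expanded as
+            # E[x^2] - 2 * mean * E[x] + mean^2 with responsibilities as weights; normalization
+            # by the summed responsibilities happens later in compute().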
+ x_prob = torch.matmul(responsibilities.t(), data.square())
+ m_prob = data_component_weights.unsqueeze(-1) * means.square()
+ xm_prob = means * torch.matmul(responsibilities.t(), data)
+ covars = x_prob - 2 * xm_prob + m_prob
+ if self.covariance_type == "diag":
+ self.covariance_sum.add_(covars)
+ else: # covariance_type == "spherical"
+ self.covariance_sum.add_(covars.mean(1))
+ elif self.covariance_type == "tied":
+ # This is taken from https://github.com/scikit-learn/scikit-learn/blob/
+ # 844b4be24d20fc42cc13b957374c718956a0db39/sklearn/mixture/_gaussian_mixture.py#L183
+ x_sq = data.T.matmul(data)
+ mean_sq = (data_component_weights * means.T).matmul(means)
+ self.covariance_sum.add_(x_sq - mean_sq)
+ else: # covariance_type == "full":
+ # We iterate over each component since this is typically faster...
+ for i in range(self.num_components):
+ component_diff = data - means[i]
+ covars = (responsibilities[:, i].unsqueeze(1) * component_diff).T.matmul(component_diff)
+ self.covariance_sum[i].add_(covars)
+
+ def compute(self) -> torch.Tensor:
+ if self.covariance_type == "diag":
+ return self.covariance_sum / self.component_weights.unsqueeze(-1) + self.reg
+ if self.covariance_type == "spherical":
+ return self.covariance_sum / self.component_weights + self.reg * self.num_features
+ if self.covariance_type == "tied":
+ result = self.covariance_sum / self.component_weights.sum()
+ shape = result.size()
+ result = result.flatten()
+ result[:: self.num_features + 1].add_(self.reg)
+ return result.view(shape)
+ # covariance_type == "full"
+ result = self.covariance_sum / self.component_weights.unsqueeze(-1).unsqueeze(-1)
+ diag_mask = (
+ torch.eye(self.num_features, device=result.device, dtype=result.dtype)
+ .bool()
+ .unsqueeze(0)
+ .expand(self.num_components, -1, -1)
+ )
+ result[diag_mask] += self.reg
+ return result
diff --git a/src/torchgmm/bayes/gmm/model.py b/src/torchgmm/bayes/gmm/model.py
new file mode 100644
index 0000000..37b205a
--- /dev/null
+++ b/src/torchgmm/bayes/gmm/model.py
@@ -0,0 +1,148 @@
+from dataclasses import dataclass
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import jit, nn
+
+from torchgmm.base.nn import Configurable
+from torchgmm.bayes.core import CovarianceType, covariance, covariance_shape
+from torchgmm.bayes.core._jit import jit_log_normal, jit_sample_normal
+
+
+@dataclass
+class GaussianMixtureModelConfig:
+ """
+ Configuration class for a Gaussian mixture model.
+
+ See Also
+ --------
+ :class:`GaussianMixtureModel`
+ """
+
+ #: The number of components in the GMM.
+ num_components: int
+ #: The number of features for the GMM's components.
+ num_features: int
+ #: The type of covariance to use for the components.
+ covariance_type: CovarianceType
+
+
+class GaussianMixtureModel(Configurable[GaussianMixtureModelConfig], nn.Module):
+ """
+ PyTorch module for a Gaussian mixture model.
+
+ Covariances are represented via their Cholesky decomposition for computational efficiency. The
+ model does not have trainable parameters.
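+
+    Example:
+        A minimal sketch of using the module directly (its buffers are random until an estimator
+        such as :class:`torchgmm.bayes.GaussianMixture` assigns fitted values)::
+
+            import torch
+
+            config = GaussianMixtureModelConfig(
+                num_components=2, num_features=3, covariance_type="diag"
+            )
+            model = GaussianMixtureModel(config)
+            log_resp, log_probs = model.forward(torch.randn(10, 3))  # shapes [10, 2] and [10]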
+ """
+
+ #: The probabilities of each component, buffer of shape ``[num_components]``.
+ component_probs: torch.Tensor
+ #: The means of each component, buffer of shape ``[num_components, num_features]``.
+ means: torch.Tensor
+    #: The Cholesky factors of the precision matrices of the components' covariances, buffer
+    #: with a shape dependent on the covariance type, see :class:`CovarianceType`.
+ precisions_cholesky: torch.Tensor
+
+ def __init__(self, config: GaussianMixtureModelConfig):
+ """
+ Args:
+ config: The configuration to use for initializing the module's buffers.
+ """
+ super().__init__(config)
+
+ self.covariance_type = config.covariance_type
+
+ self.register_buffer("component_probs", torch.empty(config.num_components))
+ self.register_buffer("means", torch.empty(config.num_components, config.num_features))
+
+ shape = covariance_shape(config.num_components, config.num_features, config.covariance_type)
+ self.register_buffer("precisions_cholesky", torch.empty(shape))
+
+ self.reset_parameters()
+
+ @jit.unused # type: ignore
+ @property
+ def covariances(self) -> torch.Tensor:
+ """
+ The covariance matrices learnt for the GMM's components.
+
+ The shape of the tensor depends on the covariance type, see :class:`CovarianceType`.
+ """
+ return covariance(self.precisions_cholesky, self.covariance_type) # type: ignore
+
+ @jit.unused
+ def reset_parameters(self) -> None:
+ """
+ Resets the parameters of the GMM.
+
+ - Component probabilities are initialized via uniform sampling and normalization.
+ - Means are initialized randomly from a Standard Normal.
+ - Cholesky precisions are initialized randomly based on the covariance type. For all
+ covariance types, it is based on uniform sampling.
+ """
+ nn.init.uniform_(self.component_probs)
+ self.component_probs.div_(self.component_probs.sum())
+
+ nn.init.normal_(self.means)
+
+ nn.init.uniform_(self.precisions_cholesky)
+ if self.covariance_type in ("full", "tied"):
+ self.precisions_cholesky.tril_()
+
+ def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Computes the log-probability of observing each of the provided datapoints for each of the
+ GMM's components.
+
+ Args:
+ data: A tensor of shape ``[num_datapoints, num_features]`` for which to compute the
+ log-probabilities.
+
+ Returns
+ -------
+ - A tensor of shape ``[num_datapoints, num_components]`` with the log-responsibilities
+          for each datapoint and component. These are the logits of the Categorical
+          distribution over the components.
+ - A tensor of shape ``[num_datapoints]`` with the log-likelihood of each datapoint.
+ """
+ log_probabilities = jit_log_normal(data, self.means, self.precisions_cholesky, self.covariance_type)
+ log_responsibilities = log_probabilities + self.component_probs.log()
+ log_prob = log_responsibilities.logsumexp(1, keepdim=True)
+ return log_responsibilities - log_prob, log_prob.squeeze(1)
+
+ def sample(self, num_datapoints: int) -> torch.Tensor:
+ """
+ Samples the provided number of datapoints from the GMM.
+
+ Args:
+ num_datapoints: The number of datapoints to sample.
+
+ Returns
+ -------
+ A tensor of shape ``[num_datapoints, num_features]`` with the random samples.
+
+ Attention:
+ This method does not automatically perform batching. If you need to sample many
+ datapoints, call this method multiple times.
+ """
+        # First, we sample the number of datapoints to generate for each component
+ component_counts = np.random.multinomial(num_datapoints, self.component_probs.numpy())
+
+        # Then, we generate datapoints for each component
+ result = []
+ for i, count in enumerate(component_counts):
+ sample = jit_sample_normal(
+ count.item(),
+ self.means[i],
+ self._get_component_precision(i),
+ self.covariance_type,
+ )
+ result.append(sample)
+
+ return torch.cat(result, dim=0)
+
+ def _get_component_precision(self, component: int) -> torch.Tensor:
+ if self.covariance_type == "tied":
+ return self.precisions_cholesky
+ return self.precisions_cholesky[component]
diff --git a/src/torchgmm/bayes/gmm/types.py b/src/torchgmm/bayes/gmm/types.py
new file mode 100644
index 0000000..579b424
--- /dev/null
+++ b/src/torchgmm/bayes/gmm/types.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from typing import Literal
+
+GaussianMixtureInitStrategy = Literal["random", "kmeans", "kmeans++"]
+GaussianMixtureInitStrategy.__doc__ = """
+Strategy for initializing the parameters of a Gaussian mixture model.
+
+- **random**: Samples responsibilities of datapoints at random and subsequently initializes means
+ and covariances from these.
+- **kmeans**: Runs K-Means via :class:`torchgmm.clustering.KMeans` and uses the centroids as the
+ initial component means. For computing the covariances, responsibilities are given as the
+ one-hot cluster assignments.
+- **kmeans++**: Runs only the K-Means++ initialization procedure to sample means in a smart
+ fashion. Might be more efficient than ``kmeans`` as it does not actually run clustering. For
+ many clusters, this is, however, still slow.
+"""
diff --git a/src/torchgmm/clustering/__init__.py b/src/torchgmm/clustering/__init__.py
new file mode 100644
index 0000000..f6f7385
--- /dev/null
+++ b/src/torchgmm/clustering/__init__.py
@@ -0,0 +1,5 @@
+from .kmeans import KMeans
+
+__all__ = [
+ "KMeans",
+]
diff --git a/src/torchgmm/clustering/kmeans/__init__.py b/src/torchgmm/clustering/kmeans/__init__.py
new file mode 100644
index 0000000..e5b77c7
--- /dev/null
+++ b/src/torchgmm/clustering/kmeans/__init__.py
@@ -0,0 +1,8 @@
+from .estimator import KMeans
+from .model import KMeansModel, KMeansModelConfig
+
+__all__ = [
+ "KMeans",
+ "KMeansModel",
+ "KMeansModelConfig",
+]
diff --git a/src/torchgmm/clustering/kmeans/estimator.py b/src/torchgmm/clustering/kmeans/estimator.py
new file mode 100644
index 0000000..00bdd93
--- /dev/null
+++ b/src/torchgmm/clustering/kmeans/estimator.py
@@ -0,0 +1,267 @@
+from __future__ import annotations
+
+import logging
+from typing import Any, List, cast
+
+import numpy as np
+import torch
+
+from torchgmm.base import ConfigurableBaseEstimator
+from torchgmm.base.data import DataLoader, TensorLike, collate_tensor, dataset_from_tensors
+from torchgmm.base.estimator import PredictorMixin, TransformerMixin
+
+from .lightning_module import (
+ FeatureVarianceLightningModule,
+ KMeansLightningModule,
+ KmeansPlusPlusInitLightningModule,
+ KmeansRandomInitLightningModule,
+)
+from .model import KMeansModel, KMeansModelConfig
+from .types import KMeansInitStrategy
+
+logger = logging.getLogger(__name__)
+
+
+class KMeans(
+ ConfigurableBaseEstimator[KMeansModel], # type: ignore
+ TransformerMixin[TensorLike, torch.Tensor],
+ PredictorMixin[TensorLike, torch.Tensor],
+):
+ """
+ Model for clustering data into a predefined number of clusters. More information on K-means
+    clustering is available on `Wikipedia <https://en.wikipedia.org/wiki/K-means_clustering>`_.
+
+ See Also
+ --------
+ .. currentmodule:: torchgmm.clustering.kmeans
+ .. autosummary::
+ :nosignatures:
+ :template: classes/pytorch_module.rst
+
+ KMeansModel
+ KMeansModelConfig
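+
+    Example:
+        A minimal usage sketch (``x`` stands in for any float32 tensor of tabular data; the
+        variable names are illustrative)::
+
+            import torch
+            from torchgmm.clustering import KMeans
+
+            x = torch.randn(1000, 2)
+            kmeans = KMeans(num_clusters=4).fit(x)
+            assignments = kmeans.predict(x)  # closest cluster per datapoint
+            distances = kmeans.transform(x)  # [1000, 4] distances to all centroids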
+ """
+
+ #: The fitted PyTorch module with all estimated parameters.
+ model_: KMeansModel
+ #: A boolean indicating whether the model converged during training.
+ converged_: bool
+ #: The number of iterations the model was fitted for, excluding initialization.
+ num_iter_: int
+ #: The mean squared distance of all datapoints to their closest cluster centers.
+ inertia_: float
+
+ def __init__(
+ self,
+ num_clusters: int = 1,
+ *,
+ init_strategy: KMeansInitStrategy = "kmeans++",
+ convergence_tolerance: float = 1e-4,
+ batch_size: int | None = None,
+ trainer_params: dict[str, Any] | None = None,
+ ):
+ """
+ Args:
+ num_clusters: The number of clusters.
+ init_strategy: The strategy for initializing centroids.
+ convergence_tolerance: Training is conducted until the Frobenius norm of the change
+ between cluster centroids falls below this threshold. The tolerance is multiplied
+ by the average variance of the features.
+ batch_size: The batch size to use when fitting the model. If not provided, the full
+ data will be used as a single batch. Set this if the full data does not fit into
+ memory.
+ trainer_params: Initialization parameters to use when initializing a PyTorch Lightning
+ trainer. By default, it disables various stdout logs unless TorchGMM is configured to
+ do verbose logging. Checkpointing and logging are disabled regardless of the log
+ level. This estimator further sets the following overridable defaults:
+
+ - ``max_epochs=300``
+
+ Note:
+                The number of epochs passed to the initializer only defines the number of
+                optimization epochs. Prior to that, initialization is run, which may perform
+                additional iterations through the data.
+ """
+ super().__init__(
+ default_params={"max_epochs": 300},
+ user_params=trainer_params,
+ )
+
+ # Assign other properties
+ self.batch_size = batch_size
+ self.num_clusters = num_clusters
+ self.init_strategy = init_strategy
+ self.convergence_tolerance = convergence_tolerance
+
+ def fit(self, data: TensorLike) -> KMeans:
+ """
+ Fits the KMeans model on the provided data by running Lloyd's algorithm.
+
+ Args:
+ data: The tabular data to fit on. The dimensionality of the KMeans model is
+ automatically inferred from this data.
+
+ Returns
+ -------
+ The fitted KMeans model.
+ """
+ if (data.dtype in [torch.float64, np.float64]) and (self.trainer_params.get("precision", 32) == 32):
+ raise ValueError(
+ "Data is of type float64. Transform it to float32 or use trainer_params={'precision': 64}."
+ )
+ # Initialize model
+ num_features = len(data[0])
+ config = KMeansModelConfig(
+ num_clusters=self.num_clusters,
+ num_features=num_features,
+ )
+ self.model_ = KMeansModel(config)
+
+ # Setup the data loading
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+ is_batch_training = self._num_batches_per_epoch(loader) > 1
+
+ # First, initialize the centroids
+ if self.init_strategy == "random":
+ module = KmeansRandomInitLightningModule(self.model_)
+ num_epochs = 1
+ else:
+ module = KmeansPlusPlusInitLightningModule(
+ self.model_,
+ is_batch_training=is_batch_training,
+ )
+ num_epochs = 2 * config.num_clusters - 1
+
+ logger.info("Running initialization...")
+ self.trainer(max_epochs=num_epochs).fit(module, loader)
+
+ # Then, in order to find the right convergence tolerance, we need to compute the variance
+ # of the data.
+ if self.convergence_tolerance != 0:
+ variances = torch.empty(config.num_features)
+ module = FeatureVarianceLightningModule(variances)
+ self.trainer().fit(module, loader)
+
+ tolerance_multiplier = cast(float, variances.mean().item())
+ convergence_tolerance = self.convergence_tolerance * tolerance_multiplier
+ else:
+ convergence_tolerance = 0
+
+ # Then, we can fit the actual model. We need a new trainer for that
+ logger.info("Fitting K-Means...")
+ trainer = self.trainer()
+ module = KMeansLightningModule(
+ self.model_,
+ convergence_tolerance=convergence_tolerance,
+ )
+ trainer.fit(module, loader)
+
+ # Assign convergence properties
+ self.num_iter_ = module.current_epoch
+ self.converged_ = module.current_epoch < trainer.max_epochs
+ if "inertia" in trainer.callback_metrics:
+ self.inertia_ = cast(float, trainer.callback_metrics["inertia"].item())
+ return self
+
+ def predict(self, data: TensorLike) -> torch.Tensor:
+ """
+ Predicts the closest cluster for each item provided.
+
+ Args:
+ data: The datapoints for which to predict the clusters.
+
+ Returns
+ -------
+ Tensor of shape ``[num_datapoints]`` with the index of the closest cluster for each
+ datapoint.
+
+ Attention:
+ When calling this function in a multi-process environment, each process receives only
+ a subset of the predictions. If you want to aggregate predictions, make sure to gather
+ the values returned from this method.
+ """
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+ result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="assignments"), loader)
+ return torch.cat(cast(List[torch.Tensor], result))
+
+ def score(self, data: TensorLike) -> float:
+ """
+ Computes the average inertia of all the provided datapoints. That is, it computes the mean
+ squared distance to each datapoint's closest centroid.
+
+ Args:
+ data: The data for which to compute the average inertia.
+
+ Returns
+ -------
+ The average inertia.
+
+ Note:
+            See :meth:`score_samples` to obtain the inertia for individual datapoints.
+ """
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+ result = self.trainer().test(KMeansLightningModule(self.model_), loader, verbose=False)
+ return result[0]["inertia"]
+
+ def score_samples(self, data: TensorLike) -> torch.Tensor:
+ """
+        Computes the inertia for each of the provided datapoints. That is, it computes the
+        squared distance of each datapoint to its closest centroid.
+
+ Args:
+ data: The data for which to compute the inertia values.
+
+ Returns
+ -------
+ A tensor of shape ``[num_datapoints]`` with the inertia of each datapoint.
+
+ Attention:
+ When calling this function in a multi-process environment, each process receives only
+ a subset of the predictions. If you want to aggregate predictions, make sure to gather
+ the values returned from this method.
+ """
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+ result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="inertias"), loader)
+ return torch.cat(cast(List[torch.Tensor], result))
+
+ def transform(self, data: TensorLike) -> torch.Tensor:
+ """
+ Transforms the provided data into the cluster-distance space. That is, it returns the
+ distance of each datapoint to each cluster centroid.
+
+ Args:
+ data: The data to transform.
+
+ Returns
+ -------
+ A tensor of shape ``[num_datapoints, num_clusters]`` with the distances to the cluster
+ centroids.
+
+ Attention:
+ When calling this function in a multi-process environment, each process receives only
+ a subset of the predictions. If you want to aggregate predictions, make sure to gather
+ the values returned from this method.
+ """
+ loader = DataLoader(
+ dataset_from_tensors(data),
+ batch_size=self.batch_size or len(data),
+ collate_fn=collate_tensor,
+ )
+ result = self.trainer().predict(KMeansLightningModule(self.model_, predict_target="distances"), loader)
+ return torch.cat(cast(List[torch.Tensor], result))
diff --git a/src/torchgmm/clustering/kmeans/lightning_module.py b/src/torchgmm/clustering/kmeans/lightning_module.py
new file mode 100644
index 0000000..c4e969d
--- /dev/null
+++ b/src/torchgmm/clustering/kmeans/lightning_module.py
@@ -0,0 +1,308 @@
+# pylint: disable=abstract-method
+import math
+from typing import List, Literal
+
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.callbacks import EarlyStopping
+from torchmetrics import MeanMetric
+
+from torchgmm.utils import NonparametricLightningModule
+
+from .metrics import (
+ BatchAverager,
+ BatchSummer,
+ CentroidAggregator,
+ DistanceSampler,
+ UniformSampler,
+)
+from .model import KMeansModel
+
+# -------------------------------------------------------------------------------------------------
+# TRAINING
+
+
+class KMeansLightningModule(NonparametricLightningModule):
+ """
+ Lightning module for training and evaluating a K-Means model.
+ """
+
+ def __init__(
+ self,
+ model: KMeansModel,
+ convergence_tolerance: float = 1e-4,
+ predict_target: Literal["assignments", "distances", "inertias"] = "assignments",
+ ):
+ """
+ Args:
+ model: The model to train.
+ convergence_tolerance: Training is conducted until the Frobenius norm of the change
+ between cluster centroids falls below this threshold.
+            predict_target: Whether to predict cluster assignments, distances to the cluster
+                centroids, or per-datapoint inertias.
+ """
+ super().__init__()
+
+ self.model = model
+ self.convergence_tolerance = convergence_tolerance
+ self.predict_target = predict_target
+
+ # Initialize aggregators
+ self.centroid_aggregator = CentroidAggregator(
+ num_clusters=self.model.config.num_clusters,
+ num_features=self.model.config.num_features,
+ dist_sync_fn=self.all_gather,
+ )
+
+ # Initialize metrics
+ self.metric_inertia = MeanMetric()
+
+ def configure_callbacks(self) -> List[pl.Callback]:
+ if self.convergence_tolerance == 0:
+ return []
+ early_stopping = EarlyStopping(
+ "frobenius_norm_change",
+ patience=100000,
+ stopping_threshold=self.convergence_tolerance,
+ check_on_train_epoch_end=True,
+ )
+ return [early_stopping]
+
+ def on_train_epoch_start(self) -> None:
+ self.centroid_aggregator.reset()
+
+ def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+ # First, we compute the cluster assignments
+ _, assignments, inertias = self.model.forward(batch)
+
+ # Then, we update the centroids
+ self.centroid_aggregator.update(batch, assignments)
+
+ # And log the inertia
+ self.metric_inertia.update(inertias)
+ self.log("inertia", self.metric_inertia, on_step=False, on_epoch=True, prog_bar=True)
+
+ def nonparametric_training_epoch_end(self) -> None:
+ centroids = self.centroid_aggregator.compute()
+ self.log("frobenius_norm_change", torch.linalg.norm(self.model.centroids - centroids))
+ self.model.centroids.copy_(centroids)
+
+ def test_step(self, batch: torch.Tensor, _batch_idx: int) -> None:
+ _, _, inertias = self.model.forward(batch)
+ self.metric_inertia.update(inertias)
+ self.log("inertia", self.metric_inertia)
+
+ def predict_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
+ distances, assignments, inertias = self.model.forward(batch)
+ if self.predict_target == "assignments":
+ return assignments
+ if self.predict_target == "inertias":
+ return inertias
+ return distances
+
+
+# -------------------------------------------------------------------------------------------------
+# INIT STRATEGIES
+
+
+class KmeansRandomInitLightningModule(NonparametricLightningModule):
+ """
+ Lightning module for initializing K-Means centroids randomly.
+
+    Within the first epoch, centroids are sampled uniformly at random from all items. Thus, this
+    module should only be trained for a single epoch.
+ """
+
+ def __init__(self, model: KMeansModel):
+ """
+ Args:
+ model: The model to initialize.
+ """
+ super().__init__()
+
+ self.model = model
+
+ self.sampler = UniformSampler(
+ num_choices=self.model.config.num_clusters,
+ num_features=self.model.config.num_features,
+ dist_sync_fn=self.all_gather_first,
+ )
+
+ def on_train_epoch_start(self) -> None:
+ self.sampler.reset()
+
+ def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+ self.sampler.update(batch)
+
+ def nonparametric_training_epoch_end(self) -> None:
+ choices = self.sampler.compute()
+ self.model.centroids.copy_(choices)
+
+
+class KmeansPlusPlusInitLightningModule(NonparametricLightningModule):
+ """
+ Lightning module for K-Means++ initialization. It performs the following operations:
+
+ - In the first epoch, a centroid is chosen at random.
+ - In even epochs, candidates for the next centroid are sampled, based on the squared distance
+ to their nearest cluster center.
+ - In odd epochs, a candidate is selected deterministically as the next centroid.
+
+ In total, initialization thus requires ``2 * k - 1`` epochs where ``k`` is the number of
+ clusters.
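+
+    For ``k = 8`` clusters, for example, this amounts to 15 epochs: one to choose the first
+    centroid, followed by 7 sampling/selection epoch pairs for the remaining centroids.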
+ """
+
+ def __init__(self, model: KMeansModel, is_batch_training: bool):
+ """
+ Args:
+ model: The model to initialize.
+ is_batch_training: Whether training is performed on mini-batches instead of the entire
+ data at once.
+ """
+ super().__init__()
+
+ self.model = model
+ self.is_batch_training = is_batch_training
+
+ self.uniform_sampler = UniformSampler(
+ num_choices=1,
+ num_features=self.model.config.num_features,
+ dist_sync_fn=self.all_gather_first,
+ )
+ num_candidates = 2 + int(math.log(self.model.config.num_clusters))
+ self.distance_sampler = DistanceSampler(
+ num_choices=num_candidates,
+ num_features=self.model.config.num_features,
+ dist_sync_fn=self.all_gather_first,
+ )
+ self.candidate_inertia_summer = BatchSummer(
+ num_candidates,
+ dist_sync_fn=self.all_gather,
+ )
+
+ # Some buffers required for running initialization
+ self.centroid_candidates: torch.Tensor
+ self.register_buffer(
+ "centroid_candidates",
+ torch.empty(num_candidates, self.model.config.num_features),
+ persistent=False,
+ )
+
+ if not self.is_batch_training:
+ self.shortest_distance_cache: torch.Tensor
+ self.register_buffer("shortest_distance_cache", torch.empty(1), persistent=False)
+
+ def on_train_epoch_start(self) -> None:
+ if self.current_epoch == 0:
+ self.uniform_sampler.reset()
+ elif self._is_current_epoch_sampling:
+ self.distance_sampler.reset()
+ else:
+ self.candidate_inertia_summer.reset()
+
+ def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+ if self.current_epoch == 0:
+ self.uniform_sampler.update(batch)
+ return
+ # In all other epochs, we either sample a number of candidates from the remaining
+ # datapoints or select a candidate deterministically. In any case, the shortest
+ # distance is required.
+ if self.current_epoch == 1:
+ # In the first epoch, we can skip any argmin as the shortest distances are computed
+ # with respect to the first centroid.
+ shortest_distances = torch.cdist(batch, self.model.centroids[:1]).squeeze(1)
+ if not self.is_batch_training:
+ self.shortest_distance_cache = shortest_distances
+ elif self.is_batch_training:
+ # For batch training, we always need to recompute all distances since we can't
+ # cache them (this is the whole reason for batch training).
+ distances = torch.cdist(batch, self.model.centroids[: self._init_epoch + 1])
+ shortest_distances = distances.gather(
+ 1,
+ distances.min(1, keepdim=True).indices, # min is faster than argmin on CPU
+ ).squeeze(1)
+ else:
+ # If we're not doing batch training, we only need to compute the distance to the
+ # newest centroid (and only if we're currently sampling)
+ if self._is_current_epoch_sampling:
+ latest_distance = torch.cdist(batch, self.model.centroids[self._init_epoch - 1].unsqueeze(0)).squeeze(1)
+ shortest_distances = torch.minimum(self.shortest_distance_cache, latest_distance)
+ self.shortest_distance_cache = shortest_distances
+ else:
+ shortest_distances = self.shortest_distance_cache
+
+ if self._is_current_epoch_sampling:
+ # After computing the shortest distances, we can finally do the sampling
+ self.distance_sampler.update(batch, shortest_distances)
+ else:
+            # Otherwise, we accumulate the inertia that each candidate would yield; the candidate
+            # with the lowest total inertia is selected at the end of the epoch.
+ distances = torch.cdist(batch, self.centroid_candidates)
+ updated_distances = torch.minimum(distances, shortest_distances.unsqueeze(1))
+ self.candidate_inertia_summer.update(updated_distances)
+
+ def nonparametric_training_epoch_end(self) -> None:
+ if self.current_epoch == 0:
+ choice = self.uniform_sampler.compute()
+ self.model.centroids[0].copy_(choice[0] if choice.dim() > 0 else choice)
+ elif self._is_current_epoch_sampling:
+ candidates = self.distance_sampler.compute()
+ self.centroid_candidates.copy_(candidates)
+ else:
+ new_inertias = self.candidate_inertia_summer.compute()
+ choice = new_inertias.argmin()
+ self.model.centroids[self._init_epoch].copy_(self.centroid_candidates[choice])
+
+ @property
+ def _init_epoch(self) -> int:
+ return (self.current_epoch + 1) // 2
+
+ @property
+ def _is_current_epoch_sampling(self) -> bool:
+ return self.current_epoch % 2 == 1
+
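+# A short sketch (illustrative only) of how the ``2 * k - 1`` epochs map onto the module's three
+# actions for ``k = 3`` clusters, using the zero-based ``current_epoch`` from above:
+#
+#     epoch 0 -> choose the first centroid uniformly at random
+#     epoch 1 -> sample candidates for centroid 1    (sampling epoch)
+#     epoch 2 -> select centroid 1 by lowest inertia (selection epoch)
+#     epoch 3 -> sample candidates for centroid 2    (sampling epoch)
+#     epoch 4 -> select centroid 2 by lowest inertia (selection epoch)
+#
+# which is exactly ``_init_epoch = (current_epoch + 1) // 2`` and
+# ``_is_current_epoch_sampling = (current_epoch % 2 == 1)``.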
+
+# -------------------------------------------------------------------------------------------------
+# MISC
+
+
+class FeatureVarianceLightningModule(NonparametricLightningModule):
+ """
+    Lightning module for computing the variances of a dataset's features.
+
+    In the first epoch, it computes the features' means; in the second epoch, it uses these means
+    to compute the per-feature variances.
+ """
+
+ def __init__(self, variances: torch.Tensor):
+ """
+ Args:
+ variances: The output tensor where the variances are stored.
+ """
+ super().__init__()
+
+ self.mean_aggregator = BatchAverager(
+ num_values=variances.size(0),
+ for_variance=False,
+ dist_sync_fn=self.all_gather,
+ )
+ self.variance_aggregator = BatchAverager(
+ num_values=variances.size(0),
+ for_variance=True,
+ dist_sync_fn=self.all_gather,
+ )
+
+ self.means: torch.Tensor
+ self.register_buffer("means", torch.empty(variances.size(0)), persistent=False)
+
+ self.variances: torch.Tensor
+ self.register_buffer("variances", variances, persistent=False)
+
+ def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+ if self.current_epoch == 0:
+ self.mean_aggregator.update(batch)
+ else:
+ self.variance_aggregator.update((batch - self.means.unsqueeze(0)).square())
+
+ def nonparametric_training_epoch_end(self) -> None:
+ if self.current_epoch == 0:
+ self.means.copy_(self.mean_aggregator.compute())
+ else:
+ self.variances.copy_(self.variance_aggregator.compute())
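+
+
+# Illustrative sketch (hypothetical, single process, all data in memory) of the two-pass
+# computation that this module performs batch-wise:
+#
+#     data = torch.randn(1000, 8)                       # [num_datapoints, num_features]
+#     means = data.mean(0)                              # epoch 0: per-feature means
+#     variances = (data - means).square().sum(0) / (data.size(0) - 1)   # epoch 1
+#
+# The batched computation yields the same result since both aggregators sum over batches before
+# dividing (``BatchAverager`` with ``for_variance=True`` divides by ``counts - 1``).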
diff --git a/src/torchgmm/clustering/kmeans/metrics.py b/src/torchgmm/clustering/kmeans/metrics.py
new file mode 100644
index 0000000..dd05ee3
--- /dev/null
+++ b/src/torchgmm/clustering/kmeans/metrics.py
@@ -0,0 +1,218 @@
+import random
+from typing import Any, Callable, Optional
+
+import torch
+from torchmetrics import Metric
+
+
+class CentroidAggregator(Metric):
+ """
+ The centroid aggregator aggregates kmeans centroids over batches and processes.
+ """
+
+ full_state_update = False
+
+ def __init__(
+ self,
+ num_clusters: int,
+ num_features: int,
+ *,
+ dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+ ):
+ super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
+
+ self.num_clusters = num_clusters
+ self.num_features = num_features
+
+ self.centroids: torch.Tensor
+ self.add_state("centroids", torch.zeros(num_clusters, num_features), dist_reduce_fx="sum")
+
+ self.cluster_counts: torch.Tensor
+ self.add_state("cluster_counts", torch.zeros(num_clusters), dist_reduce_fx="sum")
+
+ def update(self, data: torch.Tensor, assignments: torch.Tensor) -> None:
+ """Update the centroids with the data and assignments."""
+ indices = assignments.unsqueeze(1).expand(-1, self.num_features)
+ self.centroids.scatter_add_(0, indices, data.float())
+
+ counts = assignments.bincount(minlength=self.num_clusters).float()
+ self.cluster_counts.add_(counts)
+
+ def compute(self) -> torch.Tensor:
+ """Compute the centroids."""
+ return self.centroids / self.cluster_counts.unsqueeze(-1)
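+
+# Worked sketch (illustrative only; single process): with two clusters and three datapoints, the
+# aggregator reduces to the per-cluster mean of the assigned datapoints:
+#
+#     aggregator = CentroidAggregator(num_clusters=2, num_features=2)
+#     aggregator.reset()
+#     data = torch.tensor([[1.0, 1.0], [3.0, 3.0], [10.0, 10.0]])
+#     assignments = torch.tensor([0, 0, 1])
+#     aggregator.update(data, assignments)
+#     aggregator.compute()   # tensor([[ 2.,  2.], [10., 10.]])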
+
+
+class UniformSampler(Metric):
+ """
+    The uniform sampler randomly samples a specified number of datapoints uniformly from all
+    datapoints it has seen.
+
+    The idea is the following: sample ``num_choices`` datapoints from each batch and keep track of
+    the number of datapoints that the current choices were already sampled from. When sampling
+    from the union of the existing choices and a new batch, more weight is put on the existing
+    choices, proportional to the number of datapoints they were already sampled from.
+ """
+
+ full_state_update = False
+
+ def __init__(
+ self,
+ num_choices: int,
+ num_features: int,
+ *,
+ dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+ ):
+ super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
+
+ self.num_choices = num_choices
+
+ self.choices: torch.Tensor
+ self.add_state("choices", torch.empty(num_choices, num_features), dist_reduce_fx="cat")
+
+ self.choice_weights: torch.Tensor
+ self.add_state("choice_weights", torch.zeros(num_choices), dist_reduce_fx="cat")
+
+ def update(self, data: torch.Tensor) -> None:
+ if self.num_choices == 1:
+            # If there is only one choice, the fastest thing is to use the `random` package. The
+            # cumulative weight of the data is its size; the cumulative weight of the current
+            # choice is the number of datapoints it was already sampled from.
+ cum_weight = data.size(0) + self.choice_weights.item()
+ if random.random() * cum_weight < data.size(0):
+ # Use some item from the data, else keep the current choice
+ self.choices.copy_(data[random.randrange(data.size(0))])
+ else:
+ # The choices are computed from scratch every time, weighting the current choices by
+ # the cumulative weight put on them
+ weights = torch.cat(
+ [
+ torch.ones(data.size(0), device=data.device, dtype=data.dtype),
+ self.choice_weights,
+ ]
+ )
+ pool = torch.cat([data, self.choices])
+ samples = weights.multinomial(self.num_choices)
+ self.choices.copy_(pool[samples])
+
+ # The weights are the cumulative counts, divided by the number of choices
+ self.choice_weights.add_(data.size(0) / self.num_choices)
+
+ def compute(self) -> torch.Tensor:
+ # In the ddp setting, there are "too many" choices, so we sample
+ if self.choices.size(0) > self.num_choices:
+ samples = self.choice_weights.multinomial(self.num_choices)
+ return self.choices[samples]
+ return self.choices
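+
+# Usage sketch (illustrative only; single process, so no ``dist_sync_fn`` is passed):
+#
+#     sampler = UniformSampler(num_choices=2, num_features=3)
+#     sampler.reset()
+#     for batch in torch.randn(10, 32, 3):   # ten batches of 32 datapoints each
+#         sampler.update(batch)
+#     choices = sampler.compute()            # tensor of shape [2, 3]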
+
+
+class DistanceSampler(Metric):
+ """
+    The distance sampler may be used for K-Means++ initialization, to iteratively select centroids
+    according to their squared distances to existing choices.
+
+    Computing the distance to existing choices is not part of this sampler. Within each "cycle",
+    it samples a given number of candidates. Candidates are sampled independently and may
+    contain duplicates.
+ """
+
+ full_state_update = False
+
+ def __init__(
+ self,
+ num_choices: int,
+ num_features: int,
+ *,
+ dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+ ):
+ super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
+
+ self.num_choices = num_choices
+ self.num_features = num_features
+
+ self.choices: torch.Tensor
+ self.add_state("choices", torch.empty(num_choices, num_features), dist_reduce_fx="cat")
+
+ # Cumulative distance is the same for all choices
+ self.cumulative_squared_distance: torch.Tensor
+ self.add_state("cumulative_squared_distance", torch.zeros(1), dist_reduce_fx="cat")
+
+ def update(self, data: torch.Tensor, shortest_distances: torch.Tensor) -> None:
+ eps = torch.finfo(data.dtype).eps
+ squared_distances = shortest_distances.square()
+
+ # For all choices, check if we should use a sample from the data or the existing choice
+ data_dist = squared_distances.sum()
+ cum_dist = data_dist + eps + self.cumulative_squared_distance
+ use_choice_from_data = (
+ torch.rand(self.num_choices, device=data.device, dtype=data.dtype) * cum_dist < data_dist + eps
+ )
+
+ # Then, we sample from the data `num_choices` times and replace if needed
+ choices = (squared_distances + eps).multinomial(self.num_choices, replacement=True)
+ self.choices.masked_scatter_(use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]].float())
+
+ # In any case, the cumulative distances are updated
+ self.cumulative_squared_distance.add_(data_dist)
+
+ def compute(self) -> torch.Tensor:
+        # Upon computation, we sub-sample if there are more choices than requested (DDP setting)
+ if self.choices.size(0) > self.num_choices:
+            # Reshape the synced choices to shape [num_choices, num_processes, num_features]
+ choices = self.choices.reshape(-1, self.num_choices, self.num_features).transpose(0, 1)
+ # For each choice, we sample across processes
+ choice_indices = torch.arange(self.num_choices, device=self.choices.device)
+ process_indices = self.cumulative_squared_distance.multinomial(self.num_choices, replacement=True)
+ return choices[choice_indices, process_indices]
+ # Otherwise, we can return the choices
+ return self.choices
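+
+# Usage sketch (illustrative only; single process). ``shortest_distances`` holds each datapoint's
+# distance to its nearest centroid chosen so far (``centroids`` is assumed to be given), so
+# candidates are drawn with probability proportional to ``D(x)^2``:
+#
+#     sampler = DistanceSampler(num_choices=3, num_features=2)
+#     sampler.reset()
+#     batch = torch.randn(128, 2)
+#     shortest_distances = torch.cdist(batch, centroids).min(1).values
+#     sampler.update(batch, shortest_distances)
+#     candidates = sampler.compute()   # tensor of shape [3, 2], possibly containing duplicates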
+
+
+class BatchSummer(Metric):
+ """Sums the values for a batch of items independently."""
+
+ full_state_update = True
+
+ def __init__(self, num_values: int, *, dist_sync_fn: Optional[Callable[[Any], Any]] = None):
+ super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
+
+ self.sums: torch.Tensor
+ self.add_state("sums", torch.zeros(num_values), dist_reduce_fx="sum")
+
+ def update(self, values: torch.Tensor) -> None:
+ """Update the sum of the values."""
+ self.sums.add_(values.sum(0))
+
+ def compute(self) -> torch.Tensor:
+ """Compute the sum of the values."""
+ return self.sums
+
+
+class BatchAverager(Metric):
+ """Averages the values for a batch of items independently."""
+
+ full_state_update = False
+
+ def __init__(
+ self,
+ num_values: int,
+ for_variance: bool,
+ *,
+ dist_sync_fn: Optional[Callable[[Any], Any]] = None,
+ ):
+ super().__init__(dist_sync_fn=dist_sync_fn) # type: ignore
+
+ self.for_variance = for_variance
+
+ self.sums: torch.Tensor
+ self.add_state("sums", torch.zeros(num_values), dist_reduce_fx="sum")
+
+ self.counts: torch.Tensor
+ self.add_state("counts", torch.zeros(num_values), dist_reduce_fx="sum")
+
+ def update(self, values: torch.Tensor) -> None:
+ self.sums.add_(values.sum(0))
+ self.counts.add_(values.size(0))
+
+ def compute(self) -> torch.Tensor:
+ return self.sums / (self.counts - 1 if self.for_variance else self.counts)
diff --git a/src/torchgmm/clustering/kmeans/model.py b/src/torchgmm/clustering/kmeans/model.py
new file mode 100644
index 0000000..d264be6
--- /dev/null
+++ b/src/torchgmm/clustering/kmeans/model.py
@@ -0,0 +1,76 @@
+from dataclasses import dataclass
+from typing import Tuple
+
+import torch
+from torch import jit, nn
+
+from torchgmm.base.nn import Configurable
+
+
+@dataclass
+class KMeansModelConfig:
+ """
+ Configuration class for a K-Means model.
+
+ See Also
+ --------
+ :class:`KMeansModel`
+ """
+
+ #: The number of clusters.
+ num_clusters: int
+ #: The number of features of each cluster.
+ num_features: int
+
+
+class KMeansModel(Configurable[KMeansModelConfig], nn.Module):
+ """
+ PyTorch module for the K-Means model.
+
+    The centroids managed by this model are stored in a non-trainable buffer.
+ """
+
+ def __init__(self, config: KMeansModelConfig):
+ """
+ Args:
+ config: The configuration to use for initializing the module's buffers.
+ """
+ super().__init__(config)
+
+        #: The centers of all clusters, buffer of shape ``[num_clusters, num_features]``.
+ self.centroids: torch.Tensor
+ self.register_buffer("centroids", torch.empty(config.num_clusters, config.num_features))
+
+ self.reset_parameters()
+
+ @jit.unused
+ def reset_parameters(self) -> None:
+ """
+ Resets the parameters of the KMeans model.
+
+ It samples all cluster centers from a standard Normal.
+ """
+ nn.init.normal_(self.centroids)
+
+ def forward(self, data: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Computes the distance of each datapoint to each centroid as well as the "inertia", the
+ squared distance of each datapoint to its closest centroid.
+
+ Args:
+ data: A tensor of shape ``[num_datapoints, num_features]`` for which to compute the
+ distances and inertia.
+
+ Returns
+ -------
+ - A tensor of shape ``[num_datapoints, num_centroids]`` with the distance from each
+ datapoint to each centroid.
+ - A tensor of shape ``[num_datapoints]`` with the assignments, i.e. the indices of
+ each datapoint's closest centroid.
+ - A tensor of shape ``[num_datapoints]`` with the inertia (squared distance to the
+ closest centroid) of each datapoint.
+ """
+ distances = torch.cdist(data, self.centroids)
+ assignments = distances.min(1, keepdim=True).indices
+ inertias = distances.gather(1, assignments).square()
+ return distances, assignments.squeeze(1), inertias.squeeze(1)
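+
+
+# Usage sketch (illustrative only) of the three outputs of ``KMeansModel.forward``:
+#
+#     config = KMeansModelConfig(num_clusters=2, num_features=2)
+#     model = KMeansModel(config)
+#     model.centroids.copy_(torch.tensor([[0.0, 0.0], [2.0, 2.0]]))
+#     distances, assignments, inertias = model(torch.tensor([[0.5, 0.5], [2.0, 2.0]]))
+#     # distances:   [[sqrt(0.5), sqrt(4.5)], [sqrt(8.0), 0.0]]   -- shape [2, 2]
+#     # assignments: [0, 1]
+#     # inertias:    [0.5, 0.0]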
diff --git a/src/torchgmm/clustering/kmeans/types.py b/src/torchgmm/clustering/kmeans/types.py
new file mode 100644
index 0000000..523a277
--- /dev/null
+++ b/src/torchgmm/clustering/kmeans/types.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from typing import Literal
+
+KMeansInitStrategy = Literal["random", "kmeans++"]
+KMeansInitStrategy.__doc__ = """
+Strategy for initializing KMeans centroids.
+
+- **random**: Centroids are sampled randomly from the data. This has complexity ``O(n)`` for ``n``
+ datapoints.
+- **kmeans++**: Centroids are computed iteratively. The first centroid is sampled randomly from
+ the data. Subsequently, centroids are sampled from the remaining datapoints with probability
+ proportional to ``D(x)^2`` where ``D(x)`` is the distance of datapoint ``x`` to the closest
+ centroid chosen so far. This has complexity ``O(kn)`` for ``k`` clusters and ``n`` datapoints.
+ If done on mini-batches, the complexity increases to ``O(k^2 n)``.
+"""
diff --git a/src/torchgmm/pl/__init__.py b/src/torchgmm/pl/__init__.py
deleted file mode 100644
index c2315dd..0000000
--- a/src/torchgmm/pl/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .basic import BasicClass, basic_plot
diff --git a/src/torchgmm/pl/basic.py b/src/torchgmm/pl/basic.py
deleted file mode 100644
index ed390ef..0000000
--- a/src/torchgmm/pl/basic.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from anndata import AnnData
-
-
-def basic_plot(adata: AnnData) -> int:
- """Generate a basic plot for an AnnData object.
-
- Parameters
- ----------
- adata
- The AnnData object to preprocess.
-
- Returns
- -------
- Some integer value.
- """
- print("Import matplotlib and implement a plotting function here.")
- return 0
-
-
-class BasicClass:
- """A basic class.
-
- Parameters
- ----------
- adata
- The AnnData object to preprocess.
- """
-
- my_attribute: str = "Some attribute."
- my_other_attribute: int = 0
-
- def __init__(self, adata: AnnData):
- print("Implement a class here.")
-
- def my_method(self, param: int) -> int:
- """A basic method.
-
- Parameters
- ----------
- param
- A parameter.
-
- Returns
- -------
- Some integer value.
- """
- print("Implement a method here.")
- return 0
-
- def my_other_method(self, param: str) -> str:
- """Another basic method.
-
- Parameters
- ----------
- param
- A parameter.
-
- Returns
- -------
- Some integer value.
- """
- print("Implement a method here.")
- return ""
diff --git a/src/torchgmm/pp/__init__.py b/src/torchgmm/pp/__init__.py
deleted file mode 100644
index 5e7e293..0000000
--- a/src/torchgmm/pp/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .basic import basic_preproc
diff --git a/src/torchgmm/pp/basic.py b/src/torchgmm/pp/basic.py
deleted file mode 100644
index 5db1ec0..0000000
--- a/src/torchgmm/pp/basic.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from anndata import AnnData
-
-
-def basic_preproc(adata: AnnData) -> int:
- """Run a basic preprocessing on the AnnData object.
-
- Parameters
- ----------
- adata
- The AnnData object to preprocess.
-
- Returns
- -------
- Some integer value.
- """
- print("Implement a preprocessing function here.")
- return 0
diff --git a/src/torchgmm/tl/__init__.py b/src/torchgmm/tl/__init__.py
deleted file mode 100644
index 95a32cd..0000000
--- a/src/torchgmm/tl/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .basic import basic_tool
diff --git a/src/torchgmm/tl/basic.py b/src/torchgmm/tl/basic.py
deleted file mode 100644
index d215ade..0000000
--- a/src/torchgmm/tl/basic.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from anndata import AnnData
-
-
-def basic_tool(adata: AnnData) -> int:
- """Run a tool on the AnnData object.
-
- Parameters
- ----------
- adata
- The AnnData object to preprocess.
-
- Returns
- -------
- Some integer value.
- """
- print("Implement a tool to run on the AnnData object.")
- return 0
diff --git a/src/torchgmm/utils/__init__.py b/src/torchgmm/utils/__init__.py
new file mode 100644
index 0000000..5192d4f
--- /dev/null
+++ b/src/torchgmm/utils/__init__.py
@@ -0,0 +1,3 @@
+from .lightning_module import NonparametricLightningModule
+
+__all__ = ["NonparametricLightningModule"]
diff --git a/src/torchgmm/utils/lightning_module.py b/src/torchgmm/utils/lightning_module.py
new file mode 100644
index 0000000..32ea630
--- /dev/null
+++ b/src/torchgmm/utils/lightning_module.py
@@ -0,0 +1,60 @@
+from abc import ABC, abstractmethod
+from typing import List
+
+import pytorch_lightning as pl
+import torch
+from packaging import version
+from torch import nn
+
+
+class NonparametricLightningModule(pl.LightningModule, ABC):
+ """A lightning module which sets some defaults for training models with no parameters (i.e. only buffers that are optimized differently than via gradient descent)."""
+
+ def __init__(self):
+ super().__init__()
+ self.automatic_optimization = False
+
+ # Required parameter to make DDP training work
+ self.register_parameter("__ddp_dummy__", nn.Parameter(torch.empty(1)))
+
+ def configure_optimizers(self) -> None:
+ """Configure optimizers hook from PyTorch Lightning."""
+ return None
+
+ def training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+ """Training step hook from PyTorch Lightning."""
+ self.nonparametric_training_step(batch, batch_idx)
+
+ if version.parse(pl.__version__) >= version.parse("2.0.0"):
+
+ def on_train_epoch_end(self) -> None:
+ """Training epoch end hook for PyTorch Lightning >= 2.0.0."""
+ self.nonparametric_training_epoch_end()
+ else:
+
+ def training_epoch_end(self, outputs: List[torch.Tensor]) -> None:
+ """Training epoch end hook for PyTorch Lightning < 2.0.0."""
+ self.nonparametric_training_epoch_end()
+
+ @abstractmethod
+ def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+ """Training step that is not allowed to return any value."""
+
+ def nonparametric_training_epoch_end(self) -> None:
+ """
+ Training epoch end that is not passed any outputs.
+
+ Does nothing by default.
+ """
+
+ def all_gather_first(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Gathers the provided tensor from all processes.
+
+        If more than one process is available, the value from the first process is returned in
+        every process.
+ """
+ gathered = self.all_gather(x)
+ if gathered.dim() > x.dim():
+ return gathered[0]
+ return x
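+
+
+# Minimal subclassing sketch (hypothetical module, for illustration only): only the two
+# ``nonparametric_*`` hooks need to be implemented; optimization, epoch hooks and DDP handling
+# are taken care of by this base class.
+#
+#     class MeanLightningModule(NonparametricLightningModule):
+#         def __init__(self, num_features: int):
+#             super().__init__()
+#             self.register_buffer("mean", torch.zeros(num_features))
+#             self.register_buffer("running_sum", torch.zeros(num_features))
+#             self.register_buffer("count", torch.zeros(1))
+#
+#         def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+#             self.running_sum.add_(batch.sum(0))
+#             self.count.add_(batch.size(0))
+#
+#         def nonparametric_training_epoch_end(self) -> None:
+#             self.mean.copy_(self.running_sum / self.count)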
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/_data/__init__.py b/tests/_data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/_data/gmm.py b/tests/_data/gmm.py
new file mode 100644
index 0000000..7d4e13e
--- /dev/null
+++ b/tests/_data/gmm.py
@@ -0,0 +1,19 @@
+# pylint: disable=missing-function-docstring
+from typing import Tuple
+
+import torch
+
+from torchgmm.bayes.core import CovarianceType
+from torchgmm.bayes.gmm import GaussianMixtureModel, GaussianMixtureModelConfig
+
+
+def sample_gmm(
+ num_datapoints: int, num_features: int, num_components: int, covariance_type: CovarianceType
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ config = GaussianMixtureModelConfig(num_components, num_features, covariance_type)
+ model = GaussianMixtureModel(config)
+
+    # The means can simply be scaled and shifted to separate the components
+ model.means.mul_(torch.rand(num_components).unsqueeze(-1) * 10).add_(torch.rand(num_components).unsqueeze(-1) * 10)
+
+ return model.sample(num_datapoints), model.means
diff --git a/tests/_data/normal.py b/tests/_data/normal.py
new file mode 100644
index 0000000..a094cba
--- /dev/null
+++ b/tests/_data/normal.py
@@ -0,0 +1,28 @@
+# pylint: disable=missing-function-docstring
+from typing import List
+
+import torch
+
+
+def sample_data(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
+ return [torch.randn(count, dim) for count, dim in zip(counts, dims)]
+
+
+def sample_means(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
+ return [torch.randn(count, dim) for count, dim in zip(counts, dims)]
+
+
+def sample_spherical_covars(counts: List[int]) -> List[torch.Tensor]:
+ return [torch.rand(count) for count in counts]
+
+
+def sample_diag_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
+ return [torch.rand(count, dim).squeeze() for count, dim in zip(counts, dims)]
+
+
+def sample_full_covars(counts: List[int], dims: List[int]) -> List[torch.Tensor]:
+ result = []
+ for count, dim in zip(counts, dims):
+ A = torch.rand(count, dim * 10, dim)
+ result.append(A.permute(0, 2, 1).bmm(A).squeeze())
+ return result
diff --git a/tests/bayes/core/benchmark_log_normal.py b/tests/bayes/core/benchmark_log_normal.py
new file mode 100644
index 0000000..66cf06e
--- /dev/null
+++ b/tests/bayes/core/benchmark_log_normal.py
@@ -0,0 +1,142 @@
+# pylint: disable=missing-function-docstring
+import numpy as np
+import torch
+from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
+from sklearn.mixture._gaussian_mixture import (
+ _compute_precision_cholesky, # type: ignore
+ _estimate_log_gaussian_prob, # type: ignore
+)
+from torch.distributions import MultivariateNormal
+
+from torchgmm.bayes.core import cholesky_precision, log_normal
+
+
+def test_log_normal_spherical(benchmark: BenchmarkFixture):
+ data = torch.randn(10000, 100)
+ means = torch.randn(50, 100)
+ precisions = cholesky_precision(torch.rand(50), "spherical")
+ benchmark(log_normal, data, means, precisions, covariance_type="spherical")
+
+
+def test_torch_log_normal_spherical(benchmark: BenchmarkFixture):
+ data = torch.randn(10000, 100)
+ means = torch.randn(50, 100)
+ covars = torch.rand(50)
+ covar_matrices = torch.stack([torch.eye(means.size(-1)) * c for c in covars])
+
+ cholesky = torch.linalg.cholesky(covar_matrices)
+ distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False)
+ benchmark(distribution.log_prob, data.unsqueeze(1))
+
+
+def test_numpy_log_normal_spherical(benchmark: BenchmarkFixture):
+ data = np.random.randn(10000, 100)
+ means = np.random.randn(50, 100)
+ covars = np.random.rand(50)
+ benchmark(_estimate_log_gaussian_prob, data, means, covars, "spherical") # type: ignore
+
+
+# -------------------------------------------------------------------------------------------------
+
+
+def test_log_normal_diag(benchmark: BenchmarkFixture):
+ data = torch.randn(10000, 100)
+ means = torch.randn(50, 100)
+ precisions = cholesky_precision(torch.rand(50, 100), "diag")
+ benchmark(log_normal, data, means, precisions, covariance_type="diag")
+
+
+def test_torch_log_normal_diag(benchmark: BenchmarkFixture):
+ data = torch.randn(10000, 100)
+ means = torch.randn(50, 100)
+ covars = torch.rand(50, 100)
+ covar_matrices = torch.stack([torch.diag(c) for c in covars])
+
+ cholesky = torch.linalg.cholesky(covar_matrices)
+ distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False)
+ benchmark(distribution.log_prob, data.unsqueeze(1))
+
+
+def test_numpy_log_normal_diag(benchmark: BenchmarkFixture):
+ data = np.random.randn(10000, 100)
+ means = np.random.randn(50, 100)
+ covars = np.random.rand(50, 100)
+ benchmark(_estimate_log_gaussian_prob, data, means, covars, "diag") # type: ignore
+
+
+# -------------------------------------------------------------------------------------------------
+
+
+def test_log_normal_full(benchmark: BenchmarkFixture):
+ data = torch.randn(10000, 100)
+ means = torch.randn(50, 100)
+ A = torch.randn(50, 1000, 100)
+ covars = A.permute(0, 2, 1).bmm(A)
+ precisions = cholesky_precision(covars, "full")
+ benchmark(log_normal, data, means, precisions, covariance_type="full")
+
+
+def test_torch_log_normal_full(benchmark: BenchmarkFixture):
+ data = torch.randn(10000, 100)
+ means = torch.randn(50, 100)
+ A = torch.randn(50, 1000, 100)
+ covars = A.permute(0, 2, 1).bmm(A)
+
+ cholesky = torch.linalg.cholesky(covars)
+ distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False)
+ benchmark(distribution.log_prob, data.unsqueeze(1))
+
+
+def test_numpy_log_normal_full(benchmark: BenchmarkFixture):
+ data = np.random.randn(10000, 100)
+ means = np.random.randn(50, 100)
+ A = np.random.randn(50, 1000, 100)
+ covars = np.matmul(np.transpose(A, (0, 2, 1)), A)
+
+ precisions = _compute_precision_cholesky(covars, "full") # type: ignore
+ benchmark(
+ _estimate_log_gaussian_prob, # type: ignore
+ data,
+ means,
+ precisions,
+ covariance_type="full",
+ )
+
+
+# -------------------------------------------------------------------------------------------------
+
+
+def test_log_normal_tied(benchmark: BenchmarkFixture):
+ data = torch.randn(10000, 100)
+ means = torch.randn(50, 100)
+ A = torch.randn(1000, 100)
+ covars = A.t().mm(A)
+ precisions = cholesky_precision(covars, "tied")
+ benchmark(log_normal, data, means, precisions, covariance_type="tied")
+
+
+def test_torch_log_normal_tied(benchmark: BenchmarkFixture):
+ data = torch.randn(10000, 100)
+ means = torch.randn(50, 100)
+ A = torch.randn(1000, 100)
+ covars = A.t().mm(A)
+
+ cholesky = torch.linalg.cholesky(covars)
+ distribution = MultivariateNormal(means, scale_tril=cholesky, validate_args=False)
+ benchmark(distribution.log_prob, data.unsqueeze(1))
+
+
+def test_numpy_log_normal_tied(benchmark: BenchmarkFixture):
+ data = np.random.randn(10000, 100)
+ means = np.random.randn(50, 100)
+ A = np.random.randn(1000, 100)
+ covars = A.T.dot(A)
+
+ precisions = _compute_precision_cholesky(covars, "tied") # type: ignore
+ benchmark(
+ _estimate_log_gaussian_prob, # type: ignore
+ data,
+ means,
+ precisions,
+ covariance_type="tied",
+ )
diff --git a/tests/bayes/core/benchmark_precision_cholesky.py b/tests/bayes/core/benchmark_precision_cholesky.py
new file mode 100644
index 0000000..db6ed48
--- /dev/null
+++ b/tests/bayes/core/benchmark_precision_cholesky.py
@@ -0,0 +1,47 @@
+# pylint: disable=missing-function-docstring
+import numpy as np
+import torch
+from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
+from sklearn.mixture._gaussian_mixture import _compute_precision_cholesky # type: ignore
+
+from torchgmm.bayes.core import cholesky_precision
+
+
+def test_cholesky_precision_spherical(benchmark: BenchmarkFixture):
+ covars = torch.rand(50)
+ benchmark(cholesky_precision, covars, "spherical")
+
+
+def test_numpy_cholesky_precision_spherical(benchmark: BenchmarkFixture):
+ covars = np.random.rand(50)
+ benchmark(_compute_precision_cholesky, covars, "spherical") # type: ignore
+
+
+# -------------------------------------------------------------------------------------------------
+
+
+def test_cholesky_precision_tied(benchmark: BenchmarkFixture):
+ A = torch.randn(10000, 100)
+ covars = A.t().mm(A)
+ benchmark(cholesky_precision, covars, "tied")
+
+
+def test_numpy_cholesky_precision_tied(benchmark: BenchmarkFixture):
+ A = np.random.randn(10000, 100)
+ covars = np.dot(A.T, A)
+ benchmark(_compute_precision_cholesky, covars, "tied") # type: ignore
+
+
+# -------------------------------------------------------------------------------------------------
+
+
+def test_cholesky_precision_full(benchmark: BenchmarkFixture):
+ A = torch.randn(50, 10000, 100)
+ covars = A.permute(0, 2, 1).bmm(A)
+ benchmark(cholesky_precision, covars, "full")
+
+
+def test_numpy_cholesky_precision_full(benchmark: BenchmarkFixture):
+ A = np.random.randn(50, 10000, 100)
+ covars = np.matmul(np.transpose(A, (0, 2, 1)), A)
+ benchmark(_compute_precision_cholesky, covars, "full") # type: ignore
diff --git a/tests/bayes/core/test_normal.py b/tests/bayes/core/test_normal.py
new file mode 100644
index 0000000..2facf52
--- /dev/null
+++ b/tests/bayes/core/test_normal.py
@@ -0,0 +1,277 @@
+# pylint: disable=missing-function-docstring
+import pytest
+import torch
+from sklearn.mixture._gaussian_mixture import (
+ _compute_log_det_cholesky, # type: ignore
+ _compute_precision_cholesky, # type: ignore
+)
+from tests._data.normal import (
+ sample_data,
+ sample_diag_covars,
+ sample_full_covars,
+ sample_means,
+ sample_spherical_covars,
+)
+from torch.distributions import MultivariateNormal
+
+from torchgmm.bayes.core import cholesky_precision, covariance, log_normal, sample_normal
+from torchgmm.bayes.core._jit import _cholesky_logdet # type: ignore
+
+# -------------------------------------------------------------------------------------------------
+# CHOLESKY PRECISIONS
+# -------------------------------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200]))
+def test_cholesky_precision_spherical(covars: torch.Tensor):
+ expected = _compute_precision_cholesky(covars.numpy(), "spherical") # type: ignore
+ actual = cholesky_precision(covars, "spherical")
+ assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4)
+
+
+@pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100]))
+def test_cholesky_precision_diag(covars: torch.Tensor):
+ expected = _compute_precision_cholesky(covars.numpy(), "diag") # type: ignore
+ actual = cholesky_precision(covars, "diag")
+ assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4)
+
+
+@pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100]))
+def test_cholesky_precision_full(covars: torch.Tensor):
+ expected = _compute_precision_cholesky(covars.numpy(), "full") # type: ignore
+ actual = cholesky_precision(covars, "full")
+ assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4)
+
+
+@pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100]))
+def test_cholesky_precision_tied(covars: torch.Tensor):
+ expected = _compute_precision_cholesky(covars.numpy(), "tied") # type: ignore
+ actual = cholesky_precision(covars, "tied")
+ assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual, rtol=1e-4, atol=1e-4)
+
+
+# -------------------------------------------------------------------------------------------------
+# COVARIANCES
+# -------------------------------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200]))
+def test_covariances_spherical(covars: torch.Tensor):
+ precision_cholesky = _compute_precision_cholesky(covars.numpy(), "spherical") # type: ignore
+ actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.float), "spherical")
+ assert torch.allclose(covars, actual)
+
+
+@pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100]))
+def test_covariances_diag(covars: torch.Tensor):
+ precision_cholesky = _compute_precision_cholesky(covars.numpy(), "diag") # type: ignore
+ actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.float), "diag")
+ assert torch.allclose(covars, actual)
+
+
+@pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100]))
+def test_covariances_full(covars: torch.Tensor):
+ precision_cholesky = _compute_precision_cholesky(covars.numpy(), "full") # type: ignore
+ actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.double), "full")
+ assert torch.allclose(covars, covars.transpose(1, 2))
+ assert torch.allclose(covars.to(torch.double), actual)
+
+
+@pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100]))
+def test_covariances_tied(covars: torch.Tensor):
+ precision_cholesky = _compute_precision_cholesky(covars.numpy(), "tied") # type: ignore
+ actual = covariance(torch.as_tensor(precision_cholesky, dtype=torch.double), "tied")
+ assert torch.allclose(covars, covars.T)
+ assert torch.allclose(covars.to(torch.double), actual)
+
+
+# -------------------------------------------------------------------------------------------------
+# CHOLESKY LOG DETERMINANTS
+# -------------------------------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("covars", sample_spherical_covars([70, 5, 200]))
+def test_cholesky_logdet_spherical(covars: torch.Tensor):
+ expected = _compute_log_det_cholesky( # type: ignore
+ _compute_precision_cholesky(covars.numpy(), "spherical"),
+ "spherical",
+ 100, # type: ignore
+ )
+ actual = _cholesky_logdet( # type: ignore
+ 100,
+ cholesky_precision(covars, "spherical"),
+ "spherical",
+ )
+ assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual, atol=1e-4)
+
+
+@pytest.mark.parametrize("covars", sample_diag_covars([70, 5, 200], [3, 50, 100]))
+def test_cholesky_logdet_diag(covars: torch.Tensor):
+ expected = _compute_log_det_cholesky( # type: ignore
+ _compute_precision_cholesky(covars.numpy(), "diag"), # type: ignore
+ "diag",
+ covars.size(1),
+ )
+ actual = _cholesky_logdet( # type: ignore
+ covars.size(1),
+ cholesky_precision(covars, "diag"),
+ "diag",
+ )
+ assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual)
+
+
+@pytest.mark.parametrize("covars", sample_full_covars([70, 5, 200], [3, 50, 100]))
+def test_cholesky_logdet_full(covars: torch.Tensor):
+ expected = _compute_log_det_cholesky( # type: ignore
+ _compute_precision_cholesky(covars.numpy(), "full"), # type: ignore
+ "full",
+ covars.size(1),
+ )
+ actual = _cholesky_logdet( # type: ignore
+ covars.size(1),
+ cholesky_precision(covars, "full"),
+ "full",
+ )
+ assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual)
+
+
+@pytest.mark.parametrize("covars", sample_full_covars([1, 1, 1], [3, 50, 100]))
+def test_cholesky_logdet_tied(covars: torch.Tensor):
+ expected = _compute_log_det_cholesky( # type: ignore
+ _compute_precision_cholesky(covars.numpy(), "tied"), # type: ignore
+ "tied",
+ covars.size(0),
+ )
+ actual = _cholesky_logdet( # type: ignore
+ covars.size(0),
+ cholesky_precision(covars, "tied"),
+ "tied",
+ )
+ assert torch.allclose(torch.as_tensor(expected, dtype=torch.float), actual)
+
+
+# -------------------------------------------------------------------------------------------------
+# LOG NORMAL
+# -------------------------------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize(
+ "x, means, covars",
+ zip(
+ sample_data([10, 50, 100], [3, 50, 100]),
+ sample_means([70, 5, 200], [3, 50, 100]),
+ sample_spherical_covars([70, 5, 200]),
+ ),
+)
+def test_log_normal_spherical(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor):
+ covar_matrices = torch.stack([torch.eye(means.size(-1)) * c for c in covars])
+ precisions_cholesky = cholesky_precision(covars, "spherical")
+ actual = log_normal(x, means, precisions_cholesky, covariance_type="spherical")
+ _assert_log_prob(actual, x, means, covar_matrices)
+
+
+@pytest.mark.parametrize(
+ "x, means, covars",
+ zip(
+ sample_data([10, 50, 100], [3, 50, 100]),
+ sample_means([70, 5, 200], [3, 50, 100]),
+ sample_diag_covars([70, 5, 200], [3, 50, 100]),
+ ),
+)
+def test_log_normal_diag(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor):
+ covar_matrices = torch.stack([torch.diag(c) for c in covars])
+ precisions_cholesky = cholesky_precision(covars, "diag")
+ actual = log_normal(x, means, precisions_cholesky, covariance_type="diag")
+ _assert_log_prob(actual, x, means, covar_matrices)
+
+
+@pytest.mark.parametrize(
+ "x, means, covars",
+ zip(
+ sample_data([10, 50, 100], [3, 50, 100]),
+ sample_means([70, 5, 200], [3, 50, 100]),
+ sample_full_covars([70, 5, 200], [3, 50, 100]),
+ ),
+)
+def test_log_normal_full(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor):
+ precisions_cholesky = cholesky_precision(covars, "full")
+ actual = log_normal(x, means, precisions_cholesky, covariance_type="full")
+ _assert_log_prob(actual.float(), x, means, covars)
+
+
+@pytest.mark.parametrize(
+ "x, means, covars",
+ zip(
+ sample_data([10, 50, 100], [3, 50, 100]),
+ sample_means([70, 5, 200], [3, 50, 100]),
+ sample_full_covars([1, 1, 1], [3, 50, 100]),
+ ),
+)
+def test_log_normal_tied(x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor):
+ precisions_cholesky = cholesky_precision(covars, "tied")
+ actual = log_normal(x, means, precisions_cholesky, covariance_type="tied")
+ _assert_log_prob(actual, x, means, covars)
+
+
+# -------------------------------------------------------------------------------------------------
+# SAMPLING
+# -------------------------------------------------------------------------------------------------
+
+
+@pytest.mark.flaky(max_runs=3, min_passes=1)
+def test_sample_normal_spherical():
+ mean = torch.tensor([1.5, 3.5])
+ covar = torch.tensor(4.0)
+ target_covar = torch.tensor([[4.0, 0.0], [0.0, 4.0]])
+
+ n = 1_000_000
+ precisions = cholesky_precision(covar, "spherical")
+ samples = sample_normal(n, mean, precisions, "spherical")
+
+ sample_mean = samples.mean(0)
+ sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n
+
+ assert torch.allclose(mean, sample_mean, atol=1e-2)
+ assert torch.allclose(target_covar, sample_covar, atol=1e-2)
+
+
+@pytest.mark.flaky(max_runs=3, min_passes=1)
+def test_sample_normal_diag():
+ mean = torch.tensor([1.5, 3.5])
+ covar = torch.tensor([0.5, 4.5])
+ target_covar = torch.tensor([[0.5, 0.0], [0.0, 4.5]])
+
+ n = 1_000_000
+ precisions = cholesky_precision(covar, "diag")
+ samples = sample_normal(n, mean, precisions, "diag")
+
+ sample_mean = samples.mean(0)
+ sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n
+
+ assert torch.allclose(mean, sample_mean, atol=1e-2)
+ assert torch.allclose(target_covar, sample_covar, atol=1e-2)
+
+
+@pytest.mark.flaky(max_runs=3, min_passes=1)
+def test_sample_normal_full():
+ mean = torch.tensor([1.5, 3.5])
+ covar = torch.tensor([[4.0, 2.5], [2.5, 2.0]])
+
+ n = 1_000_000
+ precisions = cholesky_precision(covar, "tied")
+ samples = sample_normal(n, mean, precisions, "full")
+
+ sample_mean = samples.mean(0)
+ sample_covar = (samples - sample_mean).t().matmul(samples - sample_mean) / n
+
+ assert torch.allclose(mean, sample_mean, atol=1e-2)
+ assert torch.allclose(covar, sample_covar, atol=1e-2)
+
+
+# -------------------------------------------------------------------------------------------------
+
+
+def _assert_log_prob(actual: torch.Tensor, x: torch.Tensor, means: torch.Tensor, covars: torch.Tensor) -> None:
+ distribution = MultivariateNormal(means, covars)
+ expected = distribution.log_prob(x.unsqueeze(1))
+ assert torch.allclose(actual, expected, rtol=1e-3)
diff --git a/tests/bayes/gmm/benchmark_gmm_estimator.py b/tests/bayes/gmm/benchmark_gmm_estimator.py
new file mode 100644
index 0000000..9140964
--- /dev/null
+++ b/tests/bayes/gmm/benchmark_gmm_estimator.py
@@ -0,0 +1,135 @@
+# pylint: disable=missing-function-docstring
+from typing import Optional
+
+import pytest
+import pytorch_lightning as pl
+import torch
+from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
+from sklearn.mixture import GaussianMixture as SklearnGaussianMixture # type: ignore
+from tests._data.gmm import sample_gmm
+
+from torchgmm.bayes import GaussianMixture
+from torchgmm.bayes.core.types import CovarianceType
+
+
+@pytest.mark.parametrize(
+ ("num_datapoints", "num_features", "num_components", "covariance_type"),
+ [
+ (10000, 8, 4, "diag"),
+ (10000, 8, 4, "tied"),
+ (10000, 8, 4, "full"),
+ (100000, 32, 16, "diag"),
+ (100000, 32, 16, "tied"),
+ (100000, 32, 16, "full"),
+ (1000000, 64, 64, "diag"),
+ ],
+)
+def test_sklearn(
+ benchmark: BenchmarkFixture,
+ num_datapoints: int,
+ num_features: int,
+ num_components: int,
+ covariance_type: CovarianceType,
+):
+ pl.seed_everything(0)
+ data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type)
+
+ estimator = SklearnGaussianMixture(
+ num_components,
+ covariance_type=covariance_type,
+ tol=0,
+ n_init=1,
+ max_iter=100,
+ reg_covar=1e-2,
+ init_params="random",
+ means_init=means.numpy(),
+ )
+ benchmark(estimator.fit, data.numpy())
+
+
+@pytest.mark.parametrize(
+ ("num_datapoints", "num_features", "num_components", "covariance_type", "batch_size"),
+ [
+ (10000, 8, 4, "diag", None),
+ (10000, 8, 4, "tied", None),
+ (10000, 8, 4, "full", None),
+ (100000, 32, 16, "diag", None),
+ (100000, 32, 16, "tied", None),
+ (100000, 32, 16, "full", None),
+ (1000000, 64, 64, "diag", None),
+ (10000, 8, 4, "diag", 1000),
+ (10000, 8, 4, "tied", 1000),
+ (10000, 8, 4, "full", 1000),
+ (100000, 32, 16, "diag", 10000),
+ (100000, 32, 16, "tied", 10000),
+ (100000, 32, 16, "full", 10000),
+ (1000000, 64, 64, "diag", 100000),
+ ],
+)
+def test_torchgmm(
+ benchmark: BenchmarkFixture,
+ num_datapoints: int,
+ num_features: int,
+ num_components: int,
+ covariance_type: CovarianceType,
+ batch_size: Optional[int],
+):
+ pl.seed_everything(0)
+ data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type)
+
+ estimator = GaussianMixture(
+ num_components,
+ covariance_type=covariance_type,
+ init_means=means,
+ convergence_tolerance=0,
+ covariance_regularization=1e-2,
+ batch_size=batch_size,
+ trainer_params={"max_epochs": 100, "accelerator": "cpu"},
+ )
+ benchmark(estimator.fit, data)
+
+
+@pytest.mark.parametrize(
+ ("num_datapoints", "num_features", "num_components", "covariance_type", "batch_size"),
+ [
+ (10000, 8, 4, "diag", None),
+ (10000, 8, 4, "tied", None),
+ (10000, 8, 4, "full", None),
+ (100000, 32, 16, "diag", None),
+ (100000, 32, 16, "tied", None),
+ (100000, 32, 16, "full", None),
+ (1000000, 64, 64, "diag", None),
+ (10000, 8, 4, "diag", 1000),
+ (10000, 8, 4, "tied", 1000),
+ (10000, 8, 4, "full", 1000),
+ (100000, 32, 16, "diag", 10000),
+ (100000, 32, 16, "tied", 10000),
+ (100000, 32, 16, "full", 10000),
+ (1000000, 64, 64, "diag", 100000),
+ (1000000, 64, 64, "tied", 100000),
+ ],
+)
+def test_torchgmm_gpu(
+ benchmark: BenchmarkFixture,
+ num_datapoints: int,
+ num_features: int,
+ num_components: int,
+ covariance_type: CovarianceType,
+ batch_size: Optional[int],
+):
+ # Initialize GPU
+ torch.empty(1, device="cuda:0")
+
+ pl.seed_everything(0)
+ data, means = sample_gmm(num_datapoints, num_features, num_components, covariance_type)
+
+ estimator = GaussianMixture(
+ num_components,
+ covariance_type=covariance_type,
+ init_means=means,
+ convergence_tolerance=0,
+ covariance_regularization=1e-2,
+ batch_size=batch_size,
+ trainer_params={"max_epochs": 100, "accelerator": "gpu", "devices": 1},
+ )
+ benchmark(estimator.fit, data)
diff --git a/tests/bayes/gmm/test_gmm_estimator.py b/tests/bayes/gmm/test_gmm_estimator.py
new file mode 100644
index 0000000..701d9c2
--- /dev/null
+++ b/tests/bayes/gmm/test_gmm_estimator.py
@@ -0,0 +1,96 @@
+# pylint: disable=missing-function-docstring
+import math
+from typing import Optional
+
+import pytest
+import torch
+from sklearn.mixture import GaussianMixture as SklearnGaussianMixture # type: ignore
+from tests._data.gmm import sample_gmm
+
+from torchgmm.bayes import GaussianMixture
+from torchgmm.bayes.core import CovarianceType
+
+torch.set_num_threads(15)
+
+
+def test_fit_model_config():
+ estimator = GaussianMixture()
+ data = torch.randn(1000, 4)
+ estimator.fit(data)
+
+ assert estimator.model_.config.num_components == 1
+ assert estimator.model_.config.num_features == 4
+
+
+@pytest.mark.parametrize("batch_size", [2, None])
+def test_fit_num_iter(batch_size: Optional[int]):
+ # For the following data, K-means will find centroids [0.5, 3.5]. The estimator first computes
+    # the NLL (first iteration); afterwards, there is no improvement in the NLL (second iteration).
+ data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]])
+ estimator = GaussianMixture(
+ 2,
+ batch_size=batch_size,
+ )
+ estimator.fit(data)
+
+ assert estimator.num_iter_ == 2
+
+
+@pytest.mark.flaky(max_runs=3, min_passes=1)
+@pytest.mark.parametrize(
+ ("batch_size", "max_epochs", "converged"),
+ [(2, 1, False), (2, 3, True), (None, 1, False), (None, 3, True)],
+)
+def test_fit_converged(batch_size: Optional[int], max_epochs: int, converged: bool):
+ data = torch.as_tensor([[0.0], [1.0], [3.0], [4.0]])
+
+ estimator = GaussianMixture(
+ 2,
+ batch_size=batch_size,
+ trainer_params={"max_epochs": max_epochs},
+ )
+ estimator.fit(data)
+ assert estimator.converged_ == converged
+
+
+@pytest.mark.flaky(max_runs=25, min_passes=1)
+@pytest.mark.parametrize(
+ ("num_datapoints", "batch_size", "num_features", "num_components", "covariance_type"),
+ [
+ (10000, 10000, 4, 4, "spherical"),
+ (10000, 10000, 4, 4, "diag"),
+ (10000, 10000, 4, 4, "tied"),
+ (10000, 10000, 4, 4, "full"),
+ (10000, 1000, 4, 4, "spherical"),
+ (10000, 1000, 4, 4, "diag"),
+ (10000, 1000, 4, 4, "tied"),
+ (10000, 1000, 4, 4, "full"),
+ ],
+)
+def test_fit_nll(
+ num_datapoints: int,
+ batch_size: int,
+ num_features: int,
+ num_components: int,
+ covariance_type: CovarianceType,
+):
+ data, _ = sample_gmm(
+ num_datapoints=num_datapoints,
+ num_features=num_features,
+ num_components=num_components,
+ covariance_type=covariance_type,
+ )
+
+ # Ours
+ estimator = GaussianMixture(
+ num_components,
+ covariance_type=covariance_type,
+ batch_size=batch_size,
+ )
+ ours_nll = estimator.fit(data).score(data)
+
+ # Sklearn
+ gmm = SklearnGaussianMixture(num_components, covariance_type=covariance_type)
+ sklearn_nll = gmm.fit(data.numpy()).score(data.numpy())
+
+ assert math.isclose(ours_nll, -sklearn_nll, rel_tol=0.01, abs_tol=0.01)
diff --git a/tests/bayes/gmm/test_gmm_metrics.py b/tests/bayes/gmm/test_gmm_metrics.py
new file mode 100644
index 0000000..8b28fcc
--- /dev/null
+++ b/tests/bayes/gmm/test_gmm_metrics.py
@@ -0,0 +1,102 @@
+# pylint: disable=protected-access,missing-function-docstring
+from typing import Any, Callable
+
+import numpy as np
+import sklearn.mixture._gaussian_mixture as skgmm # type: ignore
+import torch
+
+from torchgmm.bayes.core import CovarianceType
+from torchgmm.bayes.gmm.metrics import CovarianceAggregator, MeanAggregator, PriorAggregator
+
+torch.set_num_threads(15)
+
+
+def test_prior_aggregator():
+ aggregator = PriorAggregator(3)
+ aggregator.reset()
+
+ # Step 1: single batch
+ responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
+ actual = aggregator.forward(responsibilities1)
+ expected = torch.tensor([0.5, 0.3, 0.2])
+ assert torch.allclose(actual, expected)
+
+ # Step 2: batch aggregation
+ responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]])
+ aggregator.update(responsibilities2)
+ actual = aggregator.compute()
+ expected = torch.tensor([0.54, 0.3, 0.16])
+ assert torch.allclose(actual, expected)
+
+
+def test_mean_aggregator():
+ aggregator = MeanAggregator(3, 2)
+ aggregator.reset()
+
+ # Step 1: single batch
+ data1 = torch.tensor([[5.0, 2.0], [3.0, 4.0], [1.0, 0.0]])
+ responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
+ actual = aggregator.forward(data1, responsibilities1)
+ expected = torch.tensor([[2.8667, 2.5333], [2.5556, 1.1111], [4.0, 2.0]])
+ assert torch.allclose(actual, expected, atol=1e-4)
+
+ # Step 2: batch aggregation
+ data2 = torch.tensor([[8.0, 2.5], [1.5, 4.0]])
+ responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]])
+ aggregator.update(data2, responsibilities2)
+ actual = aggregator.compute()
+ expected = torch.tensor([[3.9444, 2.7963], [3.0, 2.0667], [4.1875, 2.3125]])
+ assert torch.allclose(actual, expected, atol=1e-4)
+
+
+def test_covariance_aggregator_spherical():
+ _test_covariance("spherical", skgmm._estimate_gaussian_covariances_spherical) # type: ignore
+
+
+def test_covariance_aggregator_diag():
+ _test_covariance("diag", skgmm._estimate_gaussian_covariances_diag) # type: ignore
+
+
+def test_covariance_aggregator_tied():
+ _test_covariance("tied", skgmm._estimate_gaussian_covariances_tied) # type: ignore
+
+
+def test_covariance_aggregator_full():
+ _test_covariance("full", skgmm._estimate_gaussian_covariances_full) # type: ignore
+
+
+def _test_covariance(
+ covariance_type: CovarianceType,
+ sk_aggregator: Callable[[Any, Any, Any, Any, Any], Any],
+):
+ reg = 1e-5
+ aggregator = CovarianceAggregator(3, 2, covariance_type, reg=reg)
+ aggregator.reset()
+ means = torch.tensor([[3.0, 2.5], [2.5, 1.0], [4.0, 2.0]])
+
+ # Step 1: single batch
+ data1 = torch.tensor([[5.0, 2.0], [3.0, 4.0], [1.0, 0.0]])
+ responsibilities1 = torch.tensor([[0.3, 0.3, 0.4], [0.8, 0.1, 0.1], [0.4, 0.5, 0.1]])
+ actual = aggregator.forward(data1, responsibilities1, means)
+ expected = sk_aggregator( # type: ignore
+ responsibilities1.numpy(),
+ data1.numpy(),
+ responsibilities1.sum(0).numpy(),
+ means.numpy(),
+ reg,
+ ).astype(np.float32)
+ assert torch.allclose(actual, torch.from_numpy(expected))
+
+ # Step 2: batch aggregation
+ data2 = torch.tensor([[8.0, 2.5], [1.5, 4.0]])
+ responsibilities2 = torch.tensor([[0.7, 0.2, 0.1], [0.5, 0.4, 0.1]])
+ aggregator.update(data2, responsibilities2, means)
+ actual = aggregator.compute()
+ expected = sk_aggregator( # type: ignore
+ torch.cat([responsibilities1, responsibilities2]).numpy(),
+ torch.cat([data1, data2]).numpy(),
+ (responsibilities1.sum(0) + responsibilities2.sum(0)).numpy(),
+ means.numpy(),
+ reg,
+ ).astype(np.float32)
+ assert torch.allclose(actual, torch.from_numpy(expected))
diff --git a/tests/bayes/gmm/test_gmm_model.py b/tests/bayes/gmm/test_gmm_model.py
new file mode 100644
index 0000000..7a4a54f
--- /dev/null
+++ b/tests/bayes/gmm/test_gmm_model.py
@@ -0,0 +1,10 @@
+# pylint: disable=missing-function-docstring
+from torch import jit
+
+from torchgmm.bayes.gmm import GaussianMixtureModel, GaussianMixtureModelConfig
+
+
+def test_compile():
+ config = GaussianMixtureModelConfig(num_components=2, num_features=3, covariance_type="full")
+ model = GaussianMixtureModel(config)
+ jit.script(model)
diff --git a/tests/clustering/kmeans/benchmark_kmeans_estimator.py b/tests/clustering/kmeans/benchmark_kmeans_estimator.py
new file mode 100644
index 0000000..e95218a
--- /dev/null
+++ b/tests/clustering/kmeans/benchmark_kmeans_estimator.py
@@ -0,0 +1,127 @@
+# pylint: disable=missing-function-docstring
+from typing import Optional
+
+import pytest
+import pytorch_lightning as pl
+import torch
+from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
+from sklearn.cluster import KMeans as SklearnKMeans # type: ignore
+from tests._data.gmm import sample_gmm
+
+from torchgmm.clustering import KMeans
+from torchgmm.clustering.kmeans.types import KMeansInitStrategy
+
+
+@pytest.mark.parametrize(
+ ("num_datapoints", "num_features", "num_centroids", "init_strategy"),
+ [
+ (10000, 8, 4, "k-means++"),
+ (100000, 32, 16, "k-means++"),
+ (1000000, 64, 64, "k-means++"),
+ (10000000, 128, 128, "k-means++"),
+ (10000, 8, 4, "random"),
+ (100000, 32, 16, "random"),
+ (1000000, 64, 64, "random"),
+ (10000000, 128, 128, "random"),
+ ],
+)
+def test_sklearn(
+ benchmark: BenchmarkFixture,
+ num_datapoints: int,
+ num_features: int,
+ num_centroids: int,
+ init_strategy: str,
+):
+ pl.seed_everything(0)
+ data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical")
+
+ estimator = SklearnKMeans(
+ num_centroids,
+ algorithm="lloyd",
+ n_init=1,
+ max_iter=100,
+ tol=0,
+ init=init_strategy,
+ )
+ benchmark(estimator.fit, data.numpy())
+
+
+@pytest.mark.parametrize(
+ ("num_datapoints", "batch_size", "num_features", "num_centroids", "init_strategy"),
+ [
+ (10000, None, 8, 4, "kmeans++"),
+ (10000, 1000, 8, 4, "kmeans++"),
+ (100000, None, 32, 16, "kmeans++"),
+ (100000, 10000, 32, 16, "kmeans++"),
+ (1000000, None, 64, 64, "kmeans++"),
+ (1000000, 100000, 64, 64, "kmeans++"),
+ (10000, None, 8, 4, "random"),
+ (10000, 1000, 8, 4, "random"),
+ (100000, None, 32, 16, "random"),
+ (100000, 10000, 32, 16, "random"),
+ (1000000, None, 64, 64, "random"),
+ (1000000, 100000, 64, 64, "random"),
+ ],
+)
+def test_torchgmm(
+ benchmark: BenchmarkFixture,
+ num_datapoints: int,
+ batch_size: Optional[int],
+ num_features: int,
+ num_centroids: int,
+ init_strategy: KMeansInitStrategy,
+):
+ pl.seed_everything(0)
+ data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical")
+
+ estimator = KMeans(
+ num_centroids,
+ init_strategy=init_strategy,
+ batch_size=batch_size,
+ trainer_params={"max_epochs": 100},
+ )
+ benchmark(estimator.fit, data)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available")
+@pytest.mark.parametrize(
+ ("num_datapoints", "batch_size", "num_features", "num_centroids", "init_strategy"),
+ [
+ (10000, None, 8, 4, "kmeans++"),
+ (10000, 1000, 8, 4, "kmeans++"),
+ (100000, None, 32, 16, "kmeans++"),
+ (100000, 10000, 32, 16, "kmeans++"),
+ (1000000, None, 64, 64, "kmeans++"),
+ (1000000, 100000, 64, 64, "kmeans++"),
+ (10000000, 1000000, 128, 128, "kmeans++"),
+ (10000, None, 8, 4, "random"),
+ (10000, 1000, 8, 4, "random"),
+ (100000, None, 32, 16, "random"),
+ (100000, 10000, 32, 16, "random"),
+ (1000000, None, 64, 64, "random"),
+ (1000000, 100000, 64, 64, "random"),
+ (10000000, 1000000, 128, 128, "random"),
+ ],
+)
+def test_torchgmm_gpu(
+ benchmark: BenchmarkFixture,
+ num_datapoints: int,
+ batch_size: Optional[int],
+ num_features: int,
+ num_centroids: int,
+ init_strategy: KMeansInitStrategy,
+):
+ # Initialize GPU
+ torch.empty(1, device="cuda:0")
+
+ pl.seed_everything(0)
+ data, _ = sample_gmm(num_datapoints, num_features, num_centroids, "spherical")
+
+ estimator = KMeans(
+ num_centroids,
+ init_strategy=init_strategy,
+ batch_size=batch_size,
+ convergence_tolerance=0,
+ trainer_params={"max_epochs": 100, "accelerator": "gpu", "devices": 1},
+ )
+ benchmark(estimator.fit, data)
diff --git a/tests/clustering/kmeans/test_kmeans_estimator.py b/tests/clustering/kmeans/test_kmeans_estimator.py
new file mode 100644
index 0000000..3afb631
--- /dev/null
+++ b/tests/clustering/kmeans/test_kmeans_estimator.py
@@ -0,0 +1,87 @@
+# pylint: disable=missing-function-docstring
+import math
+from typing import Optional
+
+import pytest
+import torch
+from sklearn.cluster import KMeans as SklearnKMeans # type: ignore
+from tests._data.gmm import sample_gmm
+
+from torchgmm.clustering import KMeans
+
+torch.set_num_threads(15)
+
+
+def test_fit_automatic_config():
+ estimator = KMeans(4)
+ data = torch.cat([torch.randn(1000, 3) * 0.1 - 100, torch.randn(1000, 3) * 0.1 + 100])
+ estimator.fit(data)
+ assert estimator.model_.config.num_clusters == 4
+ assert estimator.model_.config.num_features == 3
+
+
+def test_fit_num_iter():
+ # The k-means++ iterations should find the centroids. Afterwards, it should only take a single
+ # epoch until the centroids do not change anymore.
+ data = torch.cat([torch.randn(1000, 4) * 0.1 - 100, torch.randn(1000, 4) * 0.1 + 100])
+
+ estimator = KMeans(2)
+ estimator.fit(data)
+
+ assert estimator.num_iter_ == 1
+
+
+@pytest.mark.flaky(max_runs=2, min_passes=1)
+@pytest.mark.parametrize(
+ ("num_epochs", "converged"),
+ [(100, True), (1, False)],
+)
+def test_fit_converged(num_epochs: int, converged: bool):
+ data, _ = sample_gmm(
+ num_datapoints=10000,
+ num_features=8,
+ num_components=4,
+ covariance_type="spherical",
+ )
+
+ estimator = KMeans(4, trainer_params={"max_epochs": num_epochs})
+ estimator.fit(data)
+
+ assert estimator.converged_ == converged
+
+
+@pytest.mark.flaky(max_runs=5, min_passes=1)
+@pytest.mark.parametrize(
+ ("num_datapoints", "batch_size", "num_features", "num_centroids"),
+ [
+ (10000, None, 8, 4),
+ (10000, 1000, 8, 4),
+ ],
+)
+def test_fit_inertia(
+ num_datapoints: int,
+ batch_size: Optional[int],
+ num_features: int,
+ num_centroids: int,
+):
+ data, _ = sample_gmm(
+ num_datapoints=num_datapoints,
+ num_features=num_features,
+ num_components=num_centroids,
+ covariance_type="spherical",
+ )
+
+ # Ours
+ estimator = KMeans(
+ num_centroids,
+ batch_size=batch_size,
+ )
+ ours_inertia = float("inf")
+ for _ in range(10):
+ ours_inertia = min(ours_inertia, estimator.fit(data).score(data))
+
+ # Sklearn
+ gmm = SklearnKMeans(num_centroids, n_init=10)
+ sklearn_inertia = gmm.fit(data.numpy()).score(data.numpy())
+
+ assert math.isclose(ours_inertia, -sklearn_inertia / data.size(0), rel_tol=0.01, abs_tol=0.01)
diff --git a/tests/clustering/kmeans/test_kmeans_model.py b/tests/clustering/kmeans/test_kmeans_model.py
new file mode 100644
index 0000000..5c52178
--- /dev/null
+++ b/tests/clustering/kmeans/test_kmeans_model.py
@@ -0,0 +1,30 @@
+# pylint: disable=missing-function-docstring
+import torch
+from torch import jit
+
+from torchgmm.clustering.kmeans import KMeansModel, KMeansModelConfig
+
+torch.set_num_threads(15)
+
+
+def test_compile():
+ config = KMeansModelConfig(num_clusters=2, num_features=5)
+ model = KMeansModel(config)
+ jit.script(model)
+
+
+def test_forward():
+ config = KMeansModelConfig(num_clusters=2, num_features=2)
+ model = KMeansModel(config)
+ model.centroids.copy_(torch.as_tensor([[0.0, 0.0], [2.0, 2.0]]))
+
+ X = torch.as_tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [-1.0, 4.0]])
+ distances, assignments, inertias = model.forward(X)
+
+ expected_distances = torch.as_tensor([[0.0, 8.0], [2.0, 2.0], [8.0, 0.0], [17.0, 13.0]]).sqrt()
+ expected_assignments = torch.as_tensor([0, 0, 1, 1])
+ expected_inertias = torch.as_tensor([0.0, 2.0, 0.0, 13.0])
+
+ assert torch.allclose(distances, expected_distances)
+ assert torch.all(assignments == expected_assignments)
+ assert torch.allclose(inertias, expected_inertias)
diff --git a/tests/test_basic.py b/tests/test_basic.py
deleted file mode 100644
index 63d315d..0000000
--- a/tests/test_basic.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import pytest
-
-import torchgmm
-
-
-def test_package_has_version():
- assert torchgmm.__version__ is not None
-
-
-@pytest.mark.skip(reason="This decorator should be removed when test passes.")
-def test_example():
- assert 1 == 0 # This test is designed to fail.