Commit 71849f8 (1 parent: 458ba5e), showing 70 changed files with 4,412 additions and 331 deletions.

@@ -1,38 +1,25 @@
 # API

-## Preprocessing
+## Bayesian Models

 ```{eval-rst}
-.. module:: torchgmm.pp
+.. module:: torchgmm.bayes
 .. currentmodule:: torchgmm
 .. autosummary::
     :toctree: generated
-    pp.basic_preproc
+    bayes.GaussianMixture
 ```

-## Tools
+## Clustering Models

 ```{eval-rst}
-.. module:: torchgmm.tl
+.. module:: torchgmm.clustering
 .. currentmodule:: torchgmm
 .. autosummary::
     :toctree: generated
-    tl.basic_tool
-```
-
-## Plotting
-
-```{eval-rst}
-.. module:: torchgmm.pl
-.. currentmodule:: torchgmm
-.. autosummary::
-    :toctree: generated
-    pl.basic_plot
-    pl.BasicClass
+    clustering.KMeans
 ```
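
For orientation, here is a minimal usage sketch of the two estimators the updated API page documents. It assumes the scikit-learn-style `fit`/`predict` interface that TorchGMM inherits from PyCave; parameter names such as `num_components`, `covariance_type` and `num_clusters` are assumptions to be checked against the generated API docs.

```python
import torch

from torchgmm.bayes import GaussianMixture
from torchgmm.clustering import KMeans

# Synthetic data: 10k points in 8 dimensions (the smallest benchmark size below).
data = torch.randn(10_000, 8)

# Fit a 4-component Gaussian mixture with diagonal covariances
# (parameter names assume the PyCave-style interface).
gmm = GaussianMixture(num_components=4, covariance_type="diag")
gmm.fit(data)
components = gmm.predict(data)  # most likely component per datapoint

# Fit K-means with 4 clusters (TorchGMM implements Lloyd's algorithm).
kmeans = KMeans(num_clusters=4)
kmeans.fit(data)
clusters = kmeans.predict(data)
```
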
@@ -0,0 +1,90 @@

# Benchmarks

This benchmark is based on PyCave's benchmarking, in which the runtime performance of TorchGMM was evaluated against the implementation found in scikit-learn in an exhaustive set of experiments at varying dataset sizes.

All benchmarks are run on an instance with an Intel Xeon E5-2630 v4 CPU (2.2 GHz), using at most 4 cores and 60 GiB of memory. A single GeForce GTX 1080 Ti GPU (11 GiB memory) is also available. Each benchmark is run at least 5 times to obtain the reported performance measures.

## Gaussian Mixture

### Setup

For measuring the performance of fitting a Gaussian mixture model, the number of iterations after initialization is fixed to 100 so that variance in the convergence criterion is not measured. For initialization, the models are further given the known means that were used to generate the data, to avoid issues with degenerate covariance matrices. Thus, all benchmarks essentially measure the performance after K-means initialization has been run; benchmarks for K-means itself are listed below.
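
To make this setup concrete, here is a hedged sketch of how the scikit-learn side of such a benchmark can be configured: `max_iter=100` with `tol=0` fixes the number of EM iterations, and `means_init` supplies the known generating means. The data shape mirrors the smallest benchmark configuration; this illustrates the methodology, not the original benchmark code.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)

# Known generating means: 4 components in 8 dimensions ([10k, 8] -> 4).
true_means = rng.normal(size=(4, 8)) * 5.0
data = np.concatenate(
    [rng.normal(loc=m, scale=1.0, size=(2_500, 8)) for m in true_means]
)

# Run exactly 100 EM iterations (tol=0 prevents early convergence) and
# initialize with the known means, as described above.
gmm = GaussianMixture(
    n_components=4,
    covariance_type="diag",
    max_iter=100,
    tol=0.0,
    means_init=true_means,
    n_init=1,
)
gmm.fit(data)
```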

### Results

|                    | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
| ------------------ | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
| `[10k, 8] -> 4`    | **352 ms**   | 649 ms              | 3.9 s                  | 358 ms              | 3.6 s                  |
| `[100k, 32] -> 16` | 18.4 s       | 4.3 s               | 10.0 s                 | **527 ms**          | 3.9 s                  |
| `[1M, 64] -> 64`   | 730 s        | 196 s               | 284 s                  | **7.7 s**           | 15.3 s                 |

Training Duration for Diagonal Covariance (`[num_datapoints, num_features] -> num_components`)

|                    | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
| ------------------ | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
| `[10k, 8] -> 4`    | 699 ms       | 570 ms              | 3.6 s                  | **356 ms**          | 3.3 s                  |
| `[100k, 32] -> 16` | 72.2 s       | 12.1 s              | 16.1 s                 | **919 ms**          | 3.8 s                  |
| `[1M, 64] -> 64`   | --           | --                  | --                     | --                  | **63.4 s**             |

Training Duration for Tied Covariance (`[num_datapoints, num_features] -> num_components`)

|                    | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
| ------------------ | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
| `[10k, 8] -> 4`    | 1.1 s        | 679 ms              | 4.1 s                  | **648 ms**          | 4.4 s                  |
| `[100k, 32] -> 16` | 110 s        | 13.5 s              | 21.2 s                 | **2.4 s**           | 7.8 s                  |

Training Duration for Full Covariance (`[num_datapoints, num_features] -> num_components`)

### Summary

TorchGMM's implementation of the Gaussian mixture model is markedly more efficient than the one found in scikit-learn. Even on the CPU, TorchGMM already outperforms scikit-learn significantly at 100k datapoints. When moving to the GPU, TorchGMM unfolds its full potential, yielding speedups of around 100x. For larger datasets, mini-batch training is the only option. TorchGMM fully supports it, although training takes approximately twice as long as training on the full data: the M-step of the EM algorithm needs to be split across epochs, which, in turn, requires replaying the E-step.
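
To make the last point concrete: the M-step can only be computed from sufficient statistics accumulated over the entire dataset, so with mini-batches one pass is needed just to accumulate them, and the following E-step then requires another pass with the updated parameters. In standard EM notation (not TorchGMM's exact implementation):

```latex
% E-step: responsibilities for datapoint x_n and component k
\gamma_{nk} = \frac{\pi_k \, \mathcal{N}(x_n \mid \mu_k, \Sigma_k)}
                   {\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_n \mid \mu_j, \Sigma_j)}

% Sufficient statistics, accumulated batch by batch over a full pass
N_k = \sum_n \gamma_{nk}, \qquad
S_k = \sum_n \gamma_{nk} \, x_n, \qquad
T_k = \sum_n \gamma_{nk} \, x_n x_n^\top

% M-step: new parameters, available only after all batches have been seen
\pi_k = \frac{N_k}{N}, \qquad
\mu_k = \frac{S_k}{N_k}, \qquad
\Sigma_k = \frac{T_k}{N_k} - \mu_k \mu_k^\top
```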

## K-Means

### Setup

For the scikit-learn implementation, Lloyd's algorithm is used instead of Elkan's algorithm to allow for a meaningful comparison with TorchGMM (which implements Lloyd's algorithm).

Further, the number of iterations after initialization is fixed to 100 so that variance in the convergence criterion is not measured.
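
For reference, the assignment/update loop that Lloyd's algorithm performs looks roughly as follows. This is a minimal PyTorch sketch for illustration, not TorchGMM's actual implementation.

```python
import torch


def lloyd_kmeans(x: torch.Tensor, k: int, num_iters: int = 100) -> torch.Tensor:
    """Plain Lloyd's algorithm: alternate cluster assignment and centroid update."""
    # Random initialization: pick k datapoints as the initial centroids.
    centroids = x[torch.randperm(x.size(0))[:k]].clone()
    for _ in range(num_iters):
        # Assignment step: nearest centroid for every datapoint.
        distances = torch.cdist(x, centroids)  # [n, k]
        assignments = distances.argmin(dim=1)  # [n]
        # Update step: mean of the datapoints assigned to each cluster.
        for j in range(k):
            members = x[assignments == j]
            if members.numel() > 0:
                centroids[j] = members.mean(dim=0)
    return centroids


centroids = lloyd_kmeans(torch.randn(10_000, 8), k=4)
```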

### Results

|                     | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
| ------------------- | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
| `[10k, 8] -> 4`     | **13 ms**    | 412 ms              | 797 ms                 | 387 ms              | 2.1 s                  |
| `[100k, 32] -> 16`  | **311 ms**   | 2.1 s               | 3.4 s                  | 707 ms              | 2.5 s                  |
| `[1M, 64] -> 64`    | 10.0 s       | 73.6 s              | 58.1 s                 | **8.2 s**           | 10.0 s                 |
| `[10M, 128] -> 128` | 254 s        | --                  | --                     | --                  | **133 s**              |

Training Duration for Random Initialization (`[num_datapoints, num_features] -> num_clusters`)

|                     | Scikit-Learn | TorchGMM CPU (full) | TorchGMM CPU (batches) | TorchGMM GPU (full) | TorchGMM GPU (batches) |
| ------------------- | ------------ | ------------------- | ---------------------- | ------------------- | ---------------------- |
| `[10k, 8] -> 4`     | **15 ms**    | 170 ms              | 930 ms                 | 431 ms              | 2.4 s                  |
| `[100k, 32] -> 16`  | **542 ms**   | 2.3 s               | 4.3 s                  | 840 ms              | 3.2 s                  |
| `[1M, 64] -> 64`    | 25.3 s       | 93.4 s              | 83.7 s                 | **13.1 s**          | 17.1 s                 |
| `[10M, 128] -> 128` | 827 s        | --                  | --                     | --                  | **369 s**              |

Training Duration for K-Means++ Initialization (`[num_datapoints, num_features] -> num_clusters`)

### Summary

As it turns out, the K-means implementation found in scikit-learn is really hard to outperform. Especially when little data is available, the overhead of PyTorch and PyTorch Lightning renders TorchGMM comparatively slow. As more data becomes available, however, TorchGMM becomes relatively faster and, when leveraging the GPU, finally outperforms scikit-learn at a dataset size of 1M datapoints. Nonetheless, the improvement is marginal.

@@ -9,7 +9,5 @@
 api.md
 changelog.md
 contributing.md
-references.md
-notebooks/example
+benchmark.md
 ```

This file was deleted.
@@ -1,10 +0,0 @@
@article{Virshup_2023,
    doi = {10.1038/s41587-023-01733-8},
    url = {https://doi.org/10.1038%2Fs41587-023-01733-8},
    year = 2023,
    month = {apr},
    publisher = {Springer Science and Business Media {LLC}},
    author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis and},
    title = {The scverse project provides a computational ecosystem for single-cell omics data analysis},
    journal = {Nature Biotechnology}
}

This file was deleted.