Skip to content

Commit

Permalink
[WIP] Implementation of FUGW and UCOOT (#677)
Browse files Browse the repository at this point in the history
* implementation of FUGW and UCOOT

* fix pep8 error

* fix test_utils error

* remove print

* add documentation and fix bug

* first code review

* fix documentation
  • Loading branch information
6Ulm authored Oct 14, 2024
1 parent 2aa8338 commit 791137b
Show file tree
Hide file tree
Showing 15 changed files with 3,180 additions and 118 deletions.
17 changes: 14 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ POT provides the following generic OT solvers (links to examples):
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
* [Smooth Strongly Convex Nearest Brenier Potentials](https://pythonot.github.io/auto_examples/others/plot_SSNB.html#sphx-glr-auto-examples-others-plot-ssnb-py) [58], with an extension to bounding potentials using [59].
* Gaussian Mixture Model OT [69]
* [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and
[unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71].
* Fused unbalanced Gromov-Wasserstein [70].

POT provides the following Machine Learning related solvers:

Expand All @@ -62,7 +65,7 @@ POT provides the following Machine Learning related solvers:
* [Linear OT mapping](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_linear_mapping.html) [14] and [Joint OT mapping estimation](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_mapping.html) [8].
* [Wasserstein Discriminant Analysis](https://pythonot.github.io/auto_examples/others/plot_WDA.html) [11] (requires autograd + pymanopt).
* [JCPOT algorithm for multi-source domain adaptation with target shift](https://pythonot.github.io/auto_examples/domain-adaptation/plot_otda_jcpot.html) [27].
* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53]
* [Graph Neural Network OT layers TFGW](https://pythonot.github.io/auto_examples/gromov/plot_gnn_TFGW.html) [52] and TW (OT-GNN) [53]

Some other examples are available in the [documentation](https://pythonot.github.io/auto_examples/index.html).

Expand Down Expand Up @@ -198,7 +201,7 @@ This toolbox has been created by
* [Rémi Flamary](https://remi.flamary.com/)
* [Nicolas Courty](http://people.irisa.fr/Nicolas.Courty/)

It is currently maintained by
It is currently maintained by

* [Rémi Flamary](https://remi.flamary.com/)
* [Cédric Vincent-Cuaz](https://cedricvincentcuaz.github.io/)
Expand Down Expand Up @@ -370,4 +373,12 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil

[68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing.

[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970.
[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970.

[70] A. Thual, H. Tran, T. Zemskova, N. Courty, R. Flamary, S. Dehaene
& B. Thirion (2022). [Aligning individual brains with Fused Unbalanced Gromov-Wasserstein.](https://proceedings.neurips.cc/paper_files/paper/2022/file/8906cac4ca58dcaf17e97a0486ad57ca-Paper-Conference.pdf). Neural Information Processing Systems (NeurIPS).

[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on
Artificial Intelligence.

[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).
25 changes: 13 additions & 12 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
## 0.9.5dev

#### New features
- Add feature `mass=True` for `nx.kl_div` (PR #654)
- Gaussian Mixture Model OT `ot.gmm` (PR #649)
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
- Add initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659)
- Added feature `mass=True` for `nx.kl_div` (PR #654)
- Implemented Gaussian Mixture Model OT `ot.gmm` (PR #649)
- Added feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
- Added initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659)
- Improved `ot.plot.plot1D_mat` (PR #649)
- Added `nx.det` (PR #649)
- `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649)
- restructure `ot.unbalanced` module (PR #658)
- add `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
- Restructured `ot.unbalanced` module (PR #658)
- Added `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
- Implemented Fused unbalanced Gromov-Wasserstein and unbalanced Co-Optimal Transport (PR #677)

#### Closed issues
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
Expand Down Expand Up @@ -72,7 +73,7 @@ xs, xt = np.random.randn(100, 2), np.random.randn(50, 2)

# Solve OT problem with empirical samples
sol = ot.solve_sample(xs, xt) # Exact OT betwen smaples with uniform weights
sol = ot.solve_sample(xs, xt, wa, wb) # Exact OT with weights given by user
sol = ot.solve_sample(xs, xt, wa, wb) # Exact OT with weights given by user

sol = ot.solve_sample(xs, xt, reg= 1, metric='euclidean') # sinkhorn with euclidean metric

Expand All @@ -84,15 +85,15 @@ sol = ot.solve_sample(x,x2, method='lowrank', rank=10) # compute lowrank sinkhor

value_bw = ot.solve_sample(xs, xt, method='gaussian').value # Bures-Wasserstein distance

# Solve GW problem
# Solve GW problem
Cs, Ct = ot.dist(xs, xs), ot.dist(xt, xt) # compute cost matrices
sol = ot.solve_gromov(Cs,Ct) # Exact GW between samples with uniform weights

# Solve FGW problem
M = ot.dist(xs, xt) # compute cost matrix

# Exact FGW between samples with uniform weights
sol = ot.solve_gromov(Cs, Ct, M, loss='KL', alpha=0.7) # FGW with KL data fitting
sol = ot.solve_gromov(Cs, Ct, M, loss='KL', alpha=0.7) # FGW with KL data fitting


# recover solutions objects
Expand All @@ -102,14 +103,14 @@ value = sol.value # OT value

# for GW and FGW
value_linear = sol.value_linear # linear part of the loss
value_quad = sol.value_quad # quadratic part of the loss
value_quad = sol.value_quad # quadratic part of the loss

```

Users are encouraged to use the new API (it is much simpler) but it might still be subjects to small changes before the release of POT 1.0 .


We also fixed a number of issues, the most pressing being a problem of GPU memory allocation when pytorch is installed that will not happen now thanks to Lazy initialization of the backends. We now also have the possibility to deactivate some backends using environment which prevents POT from importing them and can lead to large import speedup.
We also fixed a number of issues, the most pressing being a problem of GPU memory allocation when pytorch is installed that will not happen now thanks to Lazy initialization of the backends. We now also have the possibility to deactivate some backends using environment which prevents POT from importing them and can lead to large import speedup.


#### New features
Expand Down Expand Up @@ -143,7 +144,7 @@ We also fixed a number of issues, the most pressing being a problem of GPU memor
- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566)
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)
- Create `ot/bregman/`repository (Issue #567, PR #569)
- Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573)
- Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573)
- Fix (fused) gromov-wasserstein barycenter solvers to support `kl_loss`(PR #576)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# -*- coding: utf-8 -*-
r"""
===============================================================
Learning sample marginal distribution with CO-Optimal Transport
===============================================================
======================================================================================================================================
Detecting outliers by learning sample marginal distribution with CO-Optimal Transport and by using unbalanced Co-Optimal Transport
======================================================================================================================================
In this example, we illustrate how to estimate the sample marginal distribution which minimizes
the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data
In this example, we consider two point clouds living in different Euclidean spaces, where the outliers
are artifically injected into the target data. We illustrate two methods which allow to filter out
these outliers.
The first method requires learning the sample marginal distribution which minimizes
the CO-Optimal Transport distance [49] between two input spaces.
More precisely, given a source data
:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem
Expand All @@ -17,9 +22,19 @@
allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2`
with differentiable losses.
The second method simply requires direct application of unbalanced Co-Optimal Transport [71].
More precisely, it is enough to use the sample and feature coupling from solving
.. math::
\text{UCOOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)
where all the marginal distributions are uniform.
.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
`CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
Advances in Neural Information Processing Systems, 33.
.. [71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193).
AAAI Conference on Artificial Intelligence.
"""

# Author: Remi Flamary <[email protected]>
Expand All @@ -35,6 +50,7 @@

from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2
from ot.gromov._unbalanced import unbalanced_co_optimal_transport


# %%
Expand Down Expand Up @@ -148,3 +164,49 @@
con = ConnectionPatch(
xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
fig.add_artist(con)

# %%
# Now, let see if we can use unbalanced Co-Optimal Transport to recover the clean OT plans,
# without the need of learning the marginal distribution as in Co-Optimal Transport.
# -----------------------------------------------------------------------------------------

pi_sample, pi_feature = unbalanced_co_optimal_transport(
X=X, Y=Y_noisy, reg_marginals=(10, 10), epsilon=0, divergence="kl",
unbalanced_solver="mm", max_iter=1000, tol=1e-6,
max_iter_ot=1000, tol_ot=1e-6, log=False, verbose=False
)

# %%
# Visualizing the row and column alignments learned by unbalanced Co-Optimal Transport.
# -----------------------------------------------------------------------------------------
#
# Similar to Co-Optimal Transport, we are also be able to fully recover the clean OT plans.

fig = pl.figure(4, (9, 7))
pl.clf()

ax1 = pl.subplot(2, 2, 3)
pl.imshow(X, vmin=-2, vmax=2)
pl.xlabel('$X$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
pl.title("Transpose(Noisy $Y$)")
ax2.xaxis.tick_top()

for i in range(n1):
j = np.argmax(pi_sample[i, :])
xyA = (d1 - .5, i)
xyB = (j, d2 - .5)
con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
coordsB=ax2.transData, color="black")
fig.add_artist(con)

for i in range(d1):
j = np.argmax(pi_feature[i, :])
xyA = (i, -.5)
xyB = (-.5, j)
con = ConnectionPatch(
xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
fig.add_artist(con)
16 changes: 15 additions & 1 deletion ot/gromov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

# Author: Remi Flamary <[email protected]>
# Cedric Vincent-Cuaz <[email protected]>
# Quang Huy Tran <[email protected]>
#
# License: MIT License

# All submodules and packages
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
init_matrix_semirelaxed, semirelaxed_init_plan,
update_barycenter_structure, update_barycenter_feature)
update_barycenter_structure, update_barycenter_feature,
div_between_product, div_to_product, fused_unbalanced_across_spaces_cost,
uot_cost_matrix, uot_parameters_and_measures)

from ._gw import (gromov_wasserstein, gromov_wasserstein2,
fused_gromov_wasserstein, fused_gromov_wasserstein2,
Expand Down Expand Up @@ -63,9 +66,17 @@
quantized_fused_gromov_wasserstein_samples
)

from ._unbalanced import (fused_unbalanced_gromov_wasserstein,
fused_unbalanced_gromov_wasserstein2,
unbalanced_co_optimal_transport,
unbalanced_co_optimal_transport2,
fused_unbalanced_across_spaces_divergence)

__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
'init_matrix_semirelaxed', 'semirelaxed_init_plan',
'update_barycenter_structure', 'update_barycenter_feature',
'div_between_product', 'div_to_product', 'fused_unbalanced_across_spaces_cost',
'uot_cost_matrix', 'uot_parameters_and_measures',
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
Expand All @@ -87,4 +98,7 @@
'get_graph_representants', 'format_partitioned_graph',
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples',
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples',
'fused_unbalanced_gromov_wasserstein', 'fused_unbalanced_gromov_wasserstein2',
'unbalanced_co_optimal_transport', 'unbalanced_co_optimal_transport2',
'fused_unbalanced_across_spaces_divergence'
]
Loading

0 comments on commit 791137b

Please sign in to comment.