-
Notifications
You must be signed in to change notification settings - Fork 502
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[WIP] Implementation of FUGW and UCOOT (#677)
* implementation of FUGW and UCOOT * fix pep8 error * fix test_utils error * remove print * add documentation and fix bug * first code review * fix documentation
- Loading branch information
Showing
15 changed files
with
3,180 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,16 @@ | ||
# -*- coding: utf-8 -*- | ||
r""" | ||
=============================================================== | ||
Learning sample marginal distribution with CO-Optimal Transport | ||
=============================================================== | ||
====================================================================================================================================== | ||
Detecting outliers by learning sample marginal distribution with CO-Optimal Transport and by using unbalanced Co-Optimal Transport | ||
====================================================================================================================================== | ||
In this example, we illustrate how to estimate the sample marginal distribution which minimizes | ||
the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data | ||
In this example, we consider two point clouds living in different Euclidean spaces, where the outliers | ||
are artifically injected into the target data. We illustrate two methods which allow to filter out | ||
these outliers. | ||
The first method requires learning the sample marginal distribution which minimizes | ||
the CO-Optimal Transport distance [49] between two input spaces. | ||
More precisely, given a source data | ||
:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed | ||
histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem | ||
|
@@ -17,9 +22,19 @@ | |
allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2` | ||
with differentiable losses. | ||
The second method simply requires direct application of unbalanced Co-Optimal Transport [71]. | ||
More precisely, it is enough to use the sample and feature coupling from solving | ||
.. math:: | ||
\text{UCOOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right) | ||
where all the marginal distributions are uniform. | ||
.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). | ||
`CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_. | ||
Advances in Neural Information Processing Systems, 33. | ||
.. [71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). | ||
AAAI Conference on Artificial Intelligence. | ||
""" | ||
|
||
# Author: Remi Flamary <[email protected]> | ||
|
@@ -35,6 +50,7 @@ | |
|
||
from ot.coot import co_optimal_transport as coot | ||
from ot.coot import co_optimal_transport2 as coot2 | ||
from ot.gromov._unbalanced import unbalanced_co_optimal_transport | ||
|
||
|
||
# %% | ||
|
@@ -148,3 +164,49 @@ | |
con = ConnectionPatch( | ||
xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") | ||
fig.add_artist(con) | ||
|
||
# %% | ||
# Now, let see if we can use unbalanced Co-Optimal Transport to recover the clean OT plans, | ||
# without the need of learning the marginal distribution as in Co-Optimal Transport. | ||
# ----------------------------------------------------------------------------------------- | ||
|
||
pi_sample, pi_feature = unbalanced_co_optimal_transport( | ||
X=X, Y=Y_noisy, reg_marginals=(10, 10), epsilon=0, divergence="kl", | ||
unbalanced_solver="mm", max_iter=1000, tol=1e-6, | ||
max_iter_ot=1000, tol_ot=1e-6, log=False, verbose=False | ||
) | ||
|
||
# %% | ||
# Visualizing the row and column alignments learned by unbalanced Co-Optimal Transport. | ||
# ----------------------------------------------------------------------------------------- | ||
# | ||
# Similar to Co-Optimal Transport, we are also be able to fully recover the clean OT plans. | ||
|
||
fig = pl.figure(4, (9, 7)) | ||
pl.clf() | ||
|
||
ax1 = pl.subplot(2, 2, 3) | ||
pl.imshow(X, vmin=-2, vmax=2) | ||
pl.xlabel('$X$') | ||
|
||
ax2 = pl.subplot(2, 2, 2) | ||
ax2.yaxis.tick_right() | ||
pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2) | ||
pl.title("Transpose(Noisy $Y$)") | ||
ax2.xaxis.tick_top() | ||
|
||
for i in range(n1): | ||
j = np.argmax(pi_sample[i, :]) | ||
xyA = (d1 - .5, i) | ||
xyB = (j, d2 - .5) | ||
con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, | ||
coordsB=ax2.transData, color="black") | ||
fig.add_artist(con) | ||
|
||
for i in range(d1): | ||
j = np.argmax(pi_feature[i, :]) | ||
xyA = (i, -.5) | ||
xyB = (-.5, j) | ||
con = ConnectionPatch( | ||
xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") | ||
fig.add_artist(con) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,13 +6,16 @@ | |
|
||
# Author: Remi Flamary <[email protected]> | ||
# Cedric Vincent-Cuaz <[email protected]> | ||
# Quang Huy Tran <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
# All submodules and packages | ||
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, | ||
init_matrix_semirelaxed, semirelaxed_init_plan, | ||
update_barycenter_structure, update_barycenter_feature) | ||
update_barycenter_structure, update_barycenter_feature, | ||
div_between_product, div_to_product, fused_unbalanced_across_spaces_cost, | ||
uot_cost_matrix, uot_parameters_and_measures) | ||
|
||
from ._gw import (gromov_wasserstein, gromov_wasserstein2, | ||
fused_gromov_wasserstein, fused_gromov_wasserstein2, | ||
|
@@ -63,9 +66,17 @@ | |
quantized_fused_gromov_wasserstein_samples | ||
) | ||
|
||
from ._unbalanced import (fused_unbalanced_gromov_wasserstein, | ||
fused_unbalanced_gromov_wasserstein2, | ||
unbalanced_co_optimal_transport, | ||
unbalanced_co_optimal_transport2, | ||
fused_unbalanced_across_spaces_divergence) | ||
|
||
__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', | ||
'init_matrix_semirelaxed', 'semirelaxed_init_plan', | ||
'update_barycenter_structure', 'update_barycenter_feature', | ||
'div_between_product', 'div_to_product', 'fused_unbalanced_across_spaces_cost', | ||
'uot_cost_matrix', 'uot_parameters_and_measures', | ||
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', | ||
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', | ||
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2', | ||
|
@@ -87,4 +98,7 @@ | |
'get_graph_representants', 'format_partitioned_graph', | ||
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples', | ||
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples', | ||
'fused_unbalanced_gromov_wasserstein', 'fused_unbalanced_gromov_wasserstein2', | ||
'unbalanced_co_optimal_transport', 'unbalanced_co_optimal_transport2', | ||
'fused_unbalanced_across_spaces_divergence' | ||
] |
Oops, something went wrong.