[MRG] CO-Optimal Transport solver #447

Merged 39 commits on Mar 22, 2023
7ba5f03
Allow warmstart in sinkhorn and sinkhorn_log
6Ulm Feb 20, 2023
eabeabe
Added argument for warmstart of dual vectors in Sinkhorn-based method…
6Ulm Feb 22, 2023
cdd9373
Add the number of the PR
6Ulm Feb 22, 2023
f3d36b2
[WIP] CO-Optimal Transport
6Ulm Feb 22, 2023
53e50e5
Revert "[WIP] CO-Optimal Transport"
6Ulm Feb 22, 2023
96ea795
reformat with PEP8
6Ulm Feb 22, 2023
96449a0
Fix W291 trailing whitespace error in pep8 test
6Ulm Feb 23, 2023
49b4975
Merge branch 'master' into huy-tran
rflamary Feb 23, 2023
2571802
Rearange position of warmstart argument and edit its description
6Ulm Feb 23, 2023
60eeb46
Implementation of CO-Optimal Transport
6Ulm Feb 24, 2023
4521395
Merge branch 'master' into huy-tran
6Ulm Feb 24, 2023
217f98b
Optimize code and edit documentation
6Ulm Feb 24, 2023
020dcfe
Add number of PR
6Ulm Feb 24, 2023
bdbb64f
fix backend bug in test cases
6Ulm Feb 24, 2023
1954105
fix backend bug
6Ulm Feb 24, 2023
d851c6e
fix backend bug
6Ulm Feb 24, 2023
d5cd609
Add examples on COOT
6Ulm Feb 28, 2023
909aed8
Merge branch 'master' into huy-tran
6Ulm Feb 28, 2023
e647d85
Modify API and edit example
6Ulm Mar 14, 2023
b9bb185
Edit API
6Ulm Mar 14, 2023
fcb25a5
[WIP] CO-Optimal Transport
6Ulm Mar 14, 2023
0de914f
Merge branch 'PythonOT:master' into huy-tran
6Ulm Mar 16, 2023
0b477da
minor edit of examples and release
6Ulm Mar 16, 2023
282cac0
pull from repo
6Ulm Mar 16, 2023
eb84f61
fix bug in coot
6Ulm Mar 16, 2023
0057b8b
fix doc examples
agramfort Mar 18, 2023
2e98305
more fix of doc
agramfort Mar 18, 2023
7842457
restart CI
agramfort Mar 19, 2023
bccf896
Merge branch 'master' into huy-tran
rflamary Mar 21, 2023
17f7149
reordering ref
6Ulm Mar 21, 2023
a454f38
add more tests
6Ulm Mar 21, 2023
94a9db3
add more tests
6Ulm Mar 21, 2023
4eec61d
add test verbose
6Ulm Mar 21, 2023
62817a5
fix PEP8 bug
6Ulm Mar 21, 2023
f6b83f0
fix PEP8 bug
6Ulm Mar 21, 2023
72dca4b
fix PEP8 bug
6Ulm Mar 21, 2023
e82c624
fix pytest bug
6Ulm Mar 21, 2023
1de9c36
edit doc for better display
6Ulm Mar 21, 2023
23ae989
Merge branch 'master' into huy-tran
rflamary Mar 21, 2023
12 changes: 7 additions & 5 deletions README.md
@@ -276,15 +276,15 @@ You can also post bug reports and feature requests in Github issues. Make sure t

[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).

[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
Machine Learning (pp. 4104-4113). PMLR.

[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International
Conference on Machine Learning, PMLR 119:4692-4701, 2020

[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.

[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
@@ -305,4 +305,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer

[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787.

[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.

[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
6 changes: 4 additions & 2 deletions RELEASES.md
@@ -16,8 +16,10 @@
- New API for OT solver using function `ot.solve` (PR #388)
- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
- Added parameters method in `ot.da.SinkhornTransport` (PR #440)
- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
Pymanopt (PR #443)
- Added CO-Optimal Transport solver + examples (PR #447)
- Remove the redundant `nx.abs()` at the end of `wasserstein_1d()` (PR #448)

#### Closed issues
1 change: 1 addition & 0 deletions docs/source/all.rst
@@ -16,6 +16,7 @@ API and modules

backend
bregman
coot
da
datasets
dr
97 changes: 97 additions & 0 deletions examples/others/plot_COOT.py
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
r"""
===================================================
Row and column alignments with CO-Optimal Transport
===================================================

This example shows how to use CO-Optimal Transport [49]_ in POT.
CO-Optimal Transport computes a distance between two **arbitrary-size**
matrices and aligns their rows and columns. In this example, we consider two
random matrices :math:`X_1` and :math:`X_2` defined by
:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)`
and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`.

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
    `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
    Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import matplotlib.pylab as pl
import numpy as np
from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2

# %%
# Generating two random matrices

n1 = 20
n2 = 10
d1 = 16
d2 = 8
sigma = 0.2

X1 = (
    np.cos(np.arange(n1) * np.pi / n1)[:, None] +
    np.cos(np.arange(d1) * np.pi / d1)[None, :] +
    sigma * np.random.randn(n1, d1)
)
X2 = (
    np.cos(np.arange(n2) * np.pi / n2)[:, None] +
    np.cos(np.arange(d2) * np.pi / d2)[None, :] +
    sigma * np.random.randn(n2, d2)
)

# %%
# Visualizing the matrices

pl.figure(1, (8, 5))
pl.subplot(1, 2, 1)
pl.imshow(X1)
pl.title('$X_1$')

pl.subplot(1, 2, 2)
pl.imshow(X2)
pl.title("$X_2$")

pl.tight_layout()

# %%
# Visualizing the alignments of rows and columns, and calculating the CO-Optimal Transport distance

pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True)
coot_distance = coot2(X1, X2)
print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance))

fig = pl.figure(4, (9, 7))
pl.clf()

ax1 = pl.subplot(2, 2, 3)
pl.imshow(X1)
pl.xlabel('$X_1$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(X2))
pl.title("Transpose($X_2$)")
ax2.xaxis.tick_top()

for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
    fig.add_artist(con)
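
To make clearer what the reported distance measures, here is a minimal NumPy sketch that evaluates the COOT objective for a given pair of couplings, assuming the squared loss described in [49]. The helper names `coot_cost` and `hard_assignment` are hypothetical and are not part of POT. The solver in `ot.coot` also optimizes over both couplings, whereas this sketch only scores couplings that are already known.

```python
import numpy as np


def coot_cost(X1, X2, pi_sample, pi_feature):
    """Evaluate sum_{i,j,k,l} (X1[i,k] - X2[j,l])**2 * pi_sample[i,j] * pi_feature[k,l].

    Hypothetical helper for illustration only (not a POT function).
    """
    # diff[i, j, k, l] = X1[i, k] - X2[j, l], built by broadcasting
    diff = X1[:, None, :, None] - X2[None, :, None, :]
    return float(np.einsum("ijkl,ij,kl->", diff ** 2, pi_sample, pi_feature))


def hard_assignment(pi):
    """For each source row of a coupling, index of the target with the most mass."""
    return np.argmax(pi, axis=1)


# Toy check 1: identical matrices with diagonal couplings have zero cost.
X = np.arange(6.0).reshape(2, 3)
pi_s = np.eye(2) / 2  # each row matched to itself, uniform mass
pi_f = np.eye(3) / 3  # each column matched to itself, uniform mass
zero_cost = coot_cost(X, X, pi_s, pi_f)

# Toy check 2: matrices of different sizes with independent (product) couplings.
A = np.ones((2, 3))
B = np.zeros((4, 5))
prod_cost = coot_cost(A, B, np.full((2, 4), 1 / 8), np.full((3, 5), 1 / 15))

print(zero_cost, prod_cost, hard_assignment(pi_s))
```

The `argmax` step mirrors how the example draws one connecting line per row and per column from `pi_sample` and `pi_feature`.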
150 changes: 150 additions & 0 deletions examples/others/plot_learning_weights_with_COOT.py
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
r"""
===============================================================
Learning sample marginal distribution with CO-Optimal Transport
===============================================================

In this example, we illustrate how to estimate the sample marginal distribution that minimizes
the CO-Optimal Transport distance [49]_ between two matrices. More precisely, given source data
:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem

.. math::
    \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)

where :math:`\Delta` is the probability simplex. This minimization is done with a
simple projected gradient descent in PyTorch. We use the automatic backend of POT, which
lets :func:`ot.coot.co_optimal_transport2` return a loss that is differentiable
with respect to the sample weights.

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
    `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
    Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import torch
import numpy as np

import matplotlib.pyplot as pl
import ot

from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2


# %%
# Generate data
# -------------
# The source and clean target matrices are generated by
# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
# The target matrix is then contaminated by adding 5 row outliers.
# Intuitively, we expect that the estimated sample distribution should ignore these outliers,
# i.e. their weights should be zero.

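# Seed both generators: the outliers below are drawn with torch.randn,
# which np.random.seed alone does not control.
np.random.seed(182)
torch.manual_seed(182)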

n1, d1 = 20, 16
n2, d2 = 10, 8
n = 15

X = (
    torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
    torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
)

# Generate clean target data mixed with outliers
Y_noisy = torch.randn((n, d2)) * 10.0
Y_noisy[:n2, :] = (
    torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
    torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
)
Y = Y_noisy[:n2, :]

X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()

fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
axes[0].imshow(X, vmin=-2, vmax=2)
axes[0].set_title('$X$')

axes[1].imshow(Y, vmin=-2, vmax=2)
axes[1].set_title('Clean $Y$')

axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
axes[2].set_title('Noisy $Y$')

pl.tight_layout()

# %%
# Optimize the COOT distance with respect to the sample marginal distribution
# ---------------------------------------------------------------------------

losses = []
lr = 1e-3
niter = 1000

b = torch.tensor(ot.unif(n), requires_grad=True)

for i in range(niter):

    loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
    losses.append(float(loss))

    loss.backward()

    with torch.no_grad():
        b -= lr * b.grad  # gradient step
        b[:] = ot.utils.proj_simplex(b)  # projection on the simplex

    b.grad.zero_()

# Estimated sample marginal distribution and training loss curve
pl.plot(losses[10:])
pl.title('CO-Optimal Transport distance')

print(f"Marginal distribution = {b.detach().numpy()}")

# %%
# Visualizing the row and column alignments with the estimated sample marginal distribution
# -----------------------------------------------------------------------------------------
#
# Clearly, the learned marginal distribution completely and successfully ignores the 5 outliers.

X, Y_noisy = X.numpy(), Y_noisy.numpy()
b = b.detach().numpy()

pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)

fig = pl.figure(4, (9, 7))
pl.clf()

ax1 = pl.subplot(2, 2, 3)
pl.imshow(X, vmin=-2, vmax=2)
pl.xlabel('$X$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
pl.title("Transpose(Noisy $Y$)")
ax2.xaxis.tick_top()

for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
    fig.add_artist(con)
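
The optimization loop in this example alternates a gradient step with a projection onto the probability simplex. As a rough, self-contained illustration of that pattern, here is a NumPy sketch on a toy quadratic objective. The function `project_simplex` is a hypothetical stand-in written with the standard sort-and-threshold Euclidean projection, and it is not claimed to match the internals of `ot.utils.proj_simplex`. The cost vector `c` is an assumption, chosen so that the last three coordinates play the role of expensive outliers.

```python
import numpy as np


def project_simplex(v):
    """Euclidean projection of a 1-D vector onto {b : b >= 0, sum(b) = 1}.

    Hypothetical helper (sort-and-threshold scheme), for illustration only.
    """
    n = v.shape[0]
    u = np.sort(v)[::-1]             # entries in decreasing order
    css = np.cumsum(u) - 1.0         # cumulative sums minus the target mass
    idx = np.arange(1, n + 1)
    support = u - css / idx > 0      # leading coordinates that stay positive
    rho = idx[support][-1]           # size of the support
    tau = css[support][-1] / rho     # common shift applied to every coordinate
    return np.maximum(v - tau, 0.0)


# Toy objective f(b) = c @ b + 0.5 * ||b||^2, with gradient c + b.
c = np.array([0.0, 0.0, 10.0, 10.0, 10.0])
b = np.full(5, 0.2)                  # start from the uniform distribution
lr = 0.1

for _ in range(200):
    grad = c + b
    b = project_simplex(b - lr * grad)   # gradient step, then projection

print(b)
```

In this toy problem the weights on the three expensive coordinates are driven to zero, which is the same qualitative behaviour the example expects for the outlier rows of the noisy target.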