Commit a5e0f0d

[MRG] Add weak OT solver (#341)
* add info in release file
* update tests
* pep8
* add weak OT example
* update plot in doc
* correction example with empirical sinkhorn
* better thumbnail
* comment from review
* update documentation
1 parent 71a57c6 commit a5e0f0d

15 files changed, 343 additions and 26 deletions

README.md

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@ POT provides the following generic OT solvers (links to examples):
 * Sinkhorn divergence [23] and entropic regularization OT from empirical data.
 * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
 * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
+* Weak OT solver between empirical distributions [39]
 * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
 * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from
 * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
@@ -301,3 +302,5 @@ Conference on Machine Learning, PMLR 119:4692-4701, 2020
 
 [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
 Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
+
+[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.

RELEASES.md

Lines changed: 5 additions & 3 deletions
@@ -5,10 +5,12 @@
 
 #### New features
 
-- Better list of related examples in quick start guide with `minigallery` (PR #334)
+- Better list of related examples in quick start guide with `minigallery` (PR #334).
 - Add optional log-domain Sinkhorn implementation in WDA to support smaller values
-  of the regularization parameter (PR #336)
-- Backend implementation for `ot.lp.free_support_barycenter` (PR #340)
+  of the regularization parameter (PR #336).
+- Backend implementation for `ot.lp.free_support_barycenter` (PR #340).
+- Add weak OT solver + example (PR #341).
+
 
 #### Closed issues

docs/source/all.rst

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ API and modules
    unbalanced
    partial
    sliced
+   weak
 
 .. autosummary::
    :toctree: ../modules/generated/

examples/others/plot_WeakOT_VS_OT.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+"""
+====================================================
+Weak Optimal Transport VS exact Optimal Transport
+====================================================
+
+Illustration of 2D optimal transport between distributions that are weighted
+sums of Diracs. The OT matrix is plotted with the samples.
+
+"""
+
+# Author: Remi Flamary <remi.flamary@polytechnique.edu>
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 4
+
+import numpy as np
+import matplotlib.pylab as pl
+import ot
+import ot.plot
+
+##############################################################################
+# Generate data and plot it
+# -------------------------
+
+#%% parameters and data generation
+
+n = 50  # nb samples
+
+mu_s = np.array([0, 0])
+cov_s = np.array([[1, 0], [0, 1]])
+
+mu_t = np.array([4, 4])
+cov_t = np.array([[1, -.8], [-.8, 1]])
+
+xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+a, b = ot.unif(n), ot.unif(n)  # uniform distribution on samples
+
+# loss matrix
+M = ot.dist(xs, xt)
+M /= M.max()
+
+#%% plot samples
+
+pl.figure(1)
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.legend(loc=0)
+pl.title('Source and target distributions')
+
+pl.figure(2)
+pl.imshow(M, interpolation='nearest')
+pl.title('Cost matrix M')
+
+
+##############################################################################
+# Compute Weak OT and exact OT solutions
+# --------------------------------------
+
+#%% EMD
+
+G0 = ot.emd(a, b, M)
+
+#%% Weak OT
+
+Gweak = ot.weak_optimal_transport(xs, xt, a, b)
+
+
+##############################################################################
+# Plot weak OT and exact OT solutions
+# --------------------------------------
+
+pl.figure(3, (8, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(G0, interpolation='nearest')
+pl.title('OT matrix')
+
+pl.subplot(1, 2, 2)
+pl.imshow(Gweak, interpolation='nearest')
+pl.title('Weak OT matrix')
+
+pl.figure(4, (8, 5))
+
+pl.subplot(1, 2, 1)
+ot.plot.plot2D_samples_mat(xs, xt, G0, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('OT matrix with samples')
+
+pl.subplot(1, 2, 2)
+ot.plot.plot2D_samples_mat(xs, xt, Gweak, c=[.5, .5, 1])
+pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
+pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
+pl.title('Weak OT matrix with samples')
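
The example above calls `ot.weak_optimal_transport(xs, xt, a, b)`, but the objective it minimizes is not visible in this part of the diff. As a rough sketch, assuming the barycentric-projection form of weak OT associated with [39], where the plan G minimizes sum_i a_i * ||x_i - (1/a_i) * sum_j G_ij * y_j||^2, the cost of a given plan can be evaluated with NumPy alone. The helper name `weak_ot_cost` and the toy data are hypothetical and are not part of POT.

```python
import numpy as np


def weak_ot_cost(Xa, Xb, a, G):
    """Evaluate the assumed weak OT objective for a given plan G.

    Hypothetical helper for illustration only, not a POT function.
    """
    # Barycentric projection: each source sample is mapped to the
    # G-weighted average of the target samples it sends mass to.
    proj = (G @ Xb) / a[:, None]
    # Weighted squared distance between each source sample and its projection.
    return float(np.sum(a * np.sum((Xa - proj) ** 2, axis=1)))


# Toy data: two source and two target points with uniform weights.
Xa = np.array([[0.0, 0.0], [2.0, 0.0]])
Xb = np.array([[0.0, 1.0], [0.0, -1.0]])
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.5])

G_indep = np.outer(a, b)  # spread-out plan: every source splits its mass
G_perm = np.diag(a)       # one-to-one matching

cost_indep = weak_ot_cost(Xa, Xb, a, G_indep)
cost_perm = weak_ot_cost(Xa, Xb, a, G_perm)
```

On this toy data the spread-out plan is cheaper than the one-to-one matching under the assumed objective, which may be why the weak OT matrix in the example can look denser than the exact OT one.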

examples/plot_OT_2D_samples.py

Lines changed: 2 additions & 3 deletions
@@ -42,7 +42,6 @@
 
 # loss matrix
 M = ot.dist(xs, xt)
-M /= M.max()
 
 ##############################################################################
 # Plot data
@@ -87,7 +86,7 @@
 #%% sinkhorn
 
 # reg term
-lambd = 1e-3
+lambd = 1e-1
 
 Gs = ot.sinkhorn(a, b, M, lambd)
 
@@ -112,7 +111,7 @@
 #%% sinkhorn
 
 # reg term
-lambd = 1e-3
+lambd = 1e-1
 
 Ges = ot.bregman.empirical_sinkhorn(xs, xt, lambd)
 
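
The commit does not state why the cost normalization was dropped and `lambd` was raised in the same change. One way to see that the two settings are coupled, assuming the solver only sees the cost through the Gibbs kernel exp(-M / reg) as in standard Sinkhorn, is that dividing `M` by its maximum is equivalent to multiplying the regularization by that maximum. The data and variable names below are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 2))
xt = rng.normal(size=(5, 2)) + 4.0

# Squared Euclidean cost matrix, computed by hand to stay NumPy-only.
M = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)

reg = 1e-1
# Kernel built from the normalized cost with regularization `reg` ...
K_normalized = np.exp(-(M / M.max()) / reg)
# ... matches the kernel built from the raw cost with `reg * M.max()`.
K_rescaled = np.exp(-M / (reg * M.max()))

same_kernel = bool(np.allclose(K_normalized, K_rescaled))
```

Under that assumption, a regularization value tuned for a normalized cost matrix is not directly comparable to the same value applied to a raw cost matrix.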

ot/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -36,6 +36,7 @@
 from . import partial
 from . import backend
 from . import regpath
+from . import weak
 
 # OT functions
 from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -46,7 +47,7 @@
 from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
 from .gromov import (gromov_wasserstein, gromov_wasserstein2,
                      gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
-
+from .weak import weak_optimal_transport
 # utils functions
 from .utils import dist, unif, tic, toc, toq
 
@@ -59,5 +60,5 @@
            'sinkhorn_unbalanced', 'barycenter_unbalanced',
            'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
            'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2',
-           'max_sliced_wasserstein_distance',
+           'max_sliced_wasserstein_distance', 'weak_optimal_transport',
            'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']

ot/gromov.py

Lines changed: 16 additions & 0 deletions
@@ -338,6 +338,10 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=F
     - :math:`\mathbf{q}`: distribution in the target space
     - `L`: loss function to account for the misfit between the similarity matrices
 
+    .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
+
     Parameters
     ----------
     C1 : array-like, shape (ns, ns)
@@ -436,6 +440,10 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=
     Note that when using backends, this loss function is differentiable wrt the
     matrices and weights for quadratic loss using the gradients from [38]_.
 
+    .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
+
     Parameters
     ----------
     C1 : array-like, shape (ns, ns)
@@ -545,6 +553,10 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
     - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
     - `L` is a loss function to account for the misfit between the similarity matrices
 
+    .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
+
     The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] <references-fused-gromov-wasserstein>`
 
     Parameters
@@ -645,6 +657,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
     The algorithm used for solving the problem is conditional gradient as
     discussed in :ref:`[24] <references-fused-gromov-wasserstein2>`
 
+    .. note:: This function is backend-compatible and will work on arrays
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
+
     Note that when using backends, this loss function is differentiable wrt the
     matrices and weights for quadratic loss using the gradients from [38]_.
 

ot/lp/__init__.py

Lines changed: 7 additions & 2 deletions
@@ -26,6 +26,8 @@
 from ..utils import parmap
 from ..backend import get_backend
 
+
+
 __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
            'emd_1d', 'emd2_1d', 'wasserstein_1d']
 
@@ -220,7 +222,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
         format
 
     .. note:: This function is backend-compatible and will work on arrays
-        from all compatible backends.
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
 
     Uses the algorithm proposed in :ref:`[1] <references-emd>`.
 
@@ -358,7 +361,8 @@ def emd2(a, b, M, processes=1,
     - :math:`\mathbf{a}` and :math:`\mathbf{b}` are the sample weights
 
     .. note:: This function is backend-compatible and will work on arrays
-        from all compatible backends.
+        from all compatible backends. But the algorithm uses the C++ CPU backend
+        which can lead to copy overhead on GPU arrays.
 
     Uses the algorithm proposed in :ref:`[1] <references-emd2>`.
 
@@ -622,3 +626,4 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
         return X, log_dict
     else:
         return X
+

ot/lp/cvx.py

Lines changed: 0 additions & 1 deletion
@@ -11,7 +11,6 @@
 import scipy as sp
 import scipy.sparse as sps
 
-
 try:
     import cvxopt
     from cvxopt import solvers, matrix, spmatrix

ot/utils.py

Lines changed: 9 additions & 3 deletions
@@ -116,21 +116,27 @@ def proj_simplex(v, z=1):
     return w
 
 
-def unif(n):
+def unif(n, type_as=None):
     r"""
     Return a uniform histogram of length `n` (simplex).
 
     Parameters
     ----------
     n : int
         number of bins in the histogram
+    type_as : array_like
+        array of the same type as the expected output (numpy/pytorch/jax)
 
     Returns
     -------
-    h : np.array (`n`,)
+    h : array_like (`n`,)
         histogram of length `n` such that :math:`\forall i, \mathbf{h}_i = \frac{1}{n}`
     """
-    return np.ones((n,)) / n
+    if type_as is None:
+        return np.ones((n,)) / n
+    else:
+        nx = get_backend(type_as)
+        return nx.ones((n,)) / n
 
 
 def clean_zeros(a, b, M):
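
The `type_as` dispatch added to `unif` can be sketched without POT installed. The snippet below is a NumPy-only stand-in: `NumpyBackendSketch`, `get_backend_sketch` and `unif_sketch` are hypothetical names that mimic the control flow of the hunk above, and forwarding `type_as` to `ones` so that the dtype follows the reference array is an assumption, not something this diff shows.

```python
import numpy as np


class NumpyBackendSketch:
    """Hypothetical stand-in for a backend object (NumPy only)."""

    def ones(self, shape, type_as=None):
        # Assumption: the output dtype follows the reference array if given.
        dtype = None if type_as is None else type_as.dtype
        return np.ones(shape, dtype=dtype)


def get_backend_sketch(*arrays):
    """Pick a backend from the argument types (only NumPy is handled here)."""
    if all(isinstance(x, np.ndarray) for x in arrays):
        return NumpyBackendSketch()
    raise ValueError("this sketch only supports numpy arrays")


def unif_sketch(n, type_as=None):
    """Uniform histogram of length n, optionally matching a reference array."""
    if type_as is None:
        return np.ones((n,)) / n
    nx = get_backend_sketch(type_as)
    return nx.ones((n,), type_as=type_as) / n


h_default = unif_sketch(4)
h32 = unif_sketch(4, type_as=np.zeros(3, dtype=np.float32))
```

The point of the design is that callers working with another array library can obtain weights of a matching type by passing one of their own arrays as `type_as`, while the default call keeps returning a NumPy array.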
