Skip to content

Commit 7c5c880

Browse files
authored
Merge pull request #56 from vivienseguy/vivien-barycenters
Free support barycenters
2 parents 39cbcd3 + af57d90 commit 7c5c880

File tree

5 files changed

+183
-2
lines changed

5 files changed

+183
-2
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ It provides the following solvers:
1717
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
1818
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
1919
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
20+
* Non regularized free support Wasserstein barycenters [20].
2021
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
2122
* Optimal transport for domain adaptation with group lasso regularization [5]
2223
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -225,3 +226,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
225226
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](https://arxiv.org/abs/1605.08527). Advances in Neural Information Processing Systems (2016).
226227

227228
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A. & Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representations (2018)
229+
230+
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference on Machine Learning
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# -*- coding: utf-8 -*-
"""
========================================================
2D free support Wasserstein barycenters of distributions
========================================================

Illustration of 2D Wasserstein barycenters of distributions that are weighted
sums of Diracs.

"""

# Author: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot


##############################################################################
# Generate data
# -------------
#%% parameters and data generation
N = 3  # number of input measures
d = 2  # dimension of the ambient space
measures_locations = []
measures_weights = []

for i in range(N):

    n_i = np.random.randint(low=1, high=20)  # nb samples

    mu_i = np.random.normal(0., 4., (d,))  # Gaussian mean

    # A_i A_i^T is symmetric positive semi-definite, hence a valid covariance
    A_i = np.random.rand(d, d)
    cov_i = np.dot(A_i, A_i.transpose())  # Gaussian covariance matrix

    x_i = ot.datasets.make_2D_samples_gauss(n_i, mu_i, cov_i)  # Dirac locations
    b_i = np.random.uniform(0., 1., (n_i,))
    b_i = b_i / np.sum(b_i)  # Dirac weights

    measures_locations.append(x_i)
    measures_weights.append(b_i)


##############################################################################
# Compute free support barycenter
# -------------------------------

k = 10  # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d))  # initial Dirac locations
b = np.ones((k,)) / k  # weights of the barycenter (it will not be optimized, only the locations are optimized)

X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)


##############################################################################
# Plot data
# ---------

pl.figure(1)
for (x_i, b_i) in zip(measures_locations, measures_weights):
    # Marker areas follow the weights of the measure being drawn: b_i has one
    # entry per row of x_i, whereas b belongs to the barycenter (k entries).
    pl.scatter(x_i[:, 0], x_i[:, 1], s=b_i * 1000, label='input measure')
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, c='black', marker='^', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
pl.legend(loc=0)
pl.show()

ot/lp/__init__.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717
from .emd_wrap import emd_c, check_result
1818
from ..utils import parmap
1919
from .cvx import barycenter
20+
from ..utils import dist
2021

21-
__all__=['emd', 'emd2', 'barycenter', 'cvx']
22+
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx']
2223

2324

2425
def emd(a, b, M, numItermax=100000, log=False):
@@ -216,3 +217,95 @@ def f(b):
216217

217218
res = parmap(f, [b[:, i] for i in range(nb)], processes)
218219
return res
220+
221+
222+
223+
def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
    """
    Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)

    The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
    This problem is considered in [1] (Algorithm 2). This implementation differs from [1] in two ways:

    - we do not optimize over the weights
    - we do not do line search for the locations updates, i.e. we use theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.

    Parameters
    ----------
    measures_locations : list of (k_i,d) np.ndarray
        The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
    measures_weights : list of (k_i,) np.ndarray
        Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure

    X_init : (k,d) np.ndarray
        Initialization of the support locations (on k atoms) of the barycenter
    b : (k,) np.ndarray, optional
        Weights of the barycenter atoms (non-negatives, sum to 1). They are kept fixed, only the locations are updated. Uniform weights are used if None.
    weights : (N,) np.ndarray, optional
        Coefficient of each of the N input measures in the barycenter (non-negatives, sum to 1). Uniform coefficients are used if None.

    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the squared displacement of the support between two iterations (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    X : (k,d) np.ndarray
        Support locations (on k atoms) of the barycenter
    log_dict : dict
        Only returned if log is true. The key 'displacement_square_norms' holds the list of squared displacements, one per iteration.

    References
    ----------

    .. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

    .. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """

    iter_count = 0

    N = len(measures_locations)
    k = X_init.shape[0]
    d = X_init.shape[1]
    if b is None:
        b = np.ones((k,)) / k
    else:
        b = np.asarray(b, dtype=np.float64)
    if weights is None:
        weights = np.ones((N,)) / N
    else:
        # accept any sequence of coefficients, not only ndarrays
        weights = np.asarray(weights, dtype=np.float64)

    X = X_init

    log_dict = {}
    displacement_square_norms = []

    # start above the threshold so that the loop is entered at least once
    # (unless numItermax is 0)
    displacement_square_norm = stopThr + 1.

    # b never changes: the row scaling that turns a transport plan into a
    # barycentric projection can be computed once
    inv_b = np.reshape(1. / b, (-1, 1))

    while displacement_square_norm > stopThr and iter_count < numItermax:

        T_sum = np.zeros((k, d))

        for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):

            M_i = dist(X, measure_locations_i)
            T_i = emd(b, measure_weights_i, M_i)
            # barycentric projection of the current support onto measure i,
            # accumulated with the coefficient of that measure
            T_sum = T_sum + weight_i * inv_b * np.matmul(T_i, measure_locations_i)

        displacement_square_norm = np.sum(np.square(T_sum - X))
        if log:
            displacement_square_norms.append(displacement_square_norm)

        X = T_sum

        if verbose:
            print('iteration %d, displacement_square_norm=%f\n' % (iter_count, displacement_square_norm))

        iter_count += 1

    if log:
        log_dict['displacement_square_norms'] = displacement_square_norms
        return X, log_dict
    else:
        return X

ot/lp/cvx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import scipy as sp
1212
import scipy.sparse as sps
1313

14+
1415
try:
1516
import cvxopt
1617
from cvxopt import solvers, matrix, spmatrix
@@ -26,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):
2627

2728

2829
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
29-
"""Compute the entropic regularized wasserstein barycenter of distributions A
30+
"""Compute the Wasserstein barycenter of distributions A
3031
3132
The function solves the following optimization problem [16]:
3233

test/test_ot.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,21 @@ def test_lp_barycenter():
135135
np.testing.assert_allclose(bary.sum(), 1)
136136

137137

138+
def test_free_support_barycenter():
    """Check the free support barycenter of two unit Diracs on the real line."""

    # two single-atom measures located at -1 and +1, each carrying all its mass
    supports = [np.array([[-1.]]), np.array([[1.]])]
    masses = [np.array([1.]), np.array([1.])]

    # single barycenter atom, deliberately initialized far from the solution
    start = np.array([[-12.]])

    # obvious barycenter location between two diracs
    expected = np.array([[0.]])

    found = ot.lp.free_support_barycenter(supports, masses, start)

    np.testing.assert_allclose(found, expected, rtol=1e-5, atol=1e-7)
151+
152+
138153
@pytest.mark.skipif(not ot.lp.cvx.cvxopt, reason="No cvxopt available")
139154
def test_lp_barycenter_cvxopt():
140155

0 commit comments

Comments
 (0)