Skip to content

[MRG] Free support Sinkhorn barycenters #387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ The contributors to this library are:
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)

## Acknowledgments

Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#### New features

- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)

#### Closed issues

Expand Down
28 changes: 25 additions & 3 deletions examples/barycenters/plot_free_support_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
2D free support Wasserstein barycenters of distributions
========================================================

Illustration of 2D Wasserstein barycenters if distributions are weighted
Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
sum of diracs.

"""

# Authors: Vivien Seguy <vivien.seguy@iip.ist.i.kyoto-u.ac.jp>
# Rémi Flamary <remi.flamary@polytechnique.edu>
# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License

Expand Down Expand Up @@ -48,7 +49,7 @@


# %%
# Compute free support barycenter
# Compute free support Wasserstein barycenter
# -------------------------------

k = 200 # number of Diracs of the barycenter
Expand All @@ -58,7 +59,28 @@
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)

# %%
# Plot the barycenter
# Plot the Wasserstein barycenter
# ---------

pl.figure(2, (8, 3))
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
pl.legend(loc="lower right")
pl.show()

# %%
# Compute free support Sinkhorn barycenter

k = 200 # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)

X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)

# %%
# Plot the Wasserstein barycenter
# ---------

pl.figure(2, (8, 3))
Expand Down
151 changes: 151 additions & 0 deletions examples/barycenters/plot_free_support_sinkhorn_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
"""
========================================================
2D free support Sinkhorn barycenters of distributions
========================================================

Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds

"""

# Authors: Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License

import numpy as np
import matplotlib.pyplot as plt
import ot

# %%
# General Parameters
# ------------------
reg = 1e-2 # Entropic Regularization
numItermax = 20 # Maximum number of iterations for the Barycenter algorithm
numInnerItermax = 50 # Maximum number of sinkhorn iterations
n_samples = 200

# %%
# Generate Data
# -------------

X1 = np.random.randn(200, 2)
X2 = 2 * np.concatenate([
np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1),
np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1),
np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1),
np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1),
], axis=0)
X3 = np.random.randn(200, 2)
X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None])
X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200)

a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1))

# %%
# Inspect generated distributions
# -------------------------------

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k')
axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k')
axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k')
axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k')

axes[0].set_xlim([-3, 3])
axes[0].set_ylim([-3, 3])
axes[0].set_title('Distribution 1')

axes[1].set_xlim([-3, 3])
axes[1].set_ylim([-3, 3])
axes[1].set_title('Distribution 2')

axes[2].set_xlim([-3, 3])
axes[2].set_ylim([-3, 3])
axes[2].set_title('Distribution 3')

axes[3].set_xlim([-3, 3])
axes[3].set_ylim([-3, 3])
axes[3].set_title('Distribution 4')

plt.tight_layout()
plt.show()

# %%
# Interpolating Empirical Distributions
# -------------------------------------

fig = plt.figure(figsize=(10, 10))

weights = np.array([
[3 / 3, 0 / 3],
[2 / 3, 1 / 3],
[1 / 3, 2 / 3],
[0 / 3, 3 / 3],
]).astype(np.float32)

for k in range(4):
XB_init = np.random.randn(n_samples, 2)
XB = ot.bregman.free_support_sinkhorn_barycenter(
measures_locations=[X1, X2],
measures_weights=[a1, a2],
weights=weights[k],
X_init=XB_init,
reg=reg,
numItermax=numItermax,
numInnerItermax=numInnerItermax
)
ax = plt.subplot2grid((4, 4), (0, k))
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])

for k in range(1, 4, 1):
XB_init = np.random.randn(n_samples, 2)
XB = ot.bregman.free_support_sinkhorn_barycenter(
measures_locations=[X1, X3],
measures_weights=[a1, a2],
weights=weights[k],
X_init=XB_init,
reg=reg,
numItermax=numItermax,
numInnerItermax=numInnerItermax
)
ax = plt.subplot2grid((4, 4), (k, 0))
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])

for k in range(1, 4, 1):
XB_init = np.random.randn(n_samples, 2)
XB = ot.bregman.free_support_sinkhorn_barycenter(
measures_locations=[X3, X4],
measures_weights=[a1, a2],
weights=weights[k],
X_init=XB_init,
reg=reg,
numItermax=numItermax,
numInnerItermax=numInnerItermax
)
ax = plt.subplot2grid((4, 4), (3, k))
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])

for k in range(1, 3, 1):
XB_init = np.random.randn(n_samples, 2)
XB = ot.bregman.free_support_sinkhorn_barycenter(
measures_locations=[X2, X4],
measures_weights=[a1, a2],
weights=weights[k],
X_init=XB_init,
reg=reg,
numItermax=numItermax,
numInnerItermax=numInnerItermax
)
ax = plt.subplot2grid((4, 4), (k, 3))
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])

plt.show()
120 changes: 120 additions & 0 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,126 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)


def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
**kwargs):
r"""
Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:

.. math::
\min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)

where :

- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
- `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
- `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
- :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter

This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
There are two differences with the following codes:

- we do not optimize over the weights
- we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
:ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
implementation of the fixed-point algorithm of
:ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
- at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).

Parameters
----------
measures_locations : list of N (k_i,d) array-like
The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
(:math:`k_i` can be different for each element of the list)
measures_weights : list of N (k_i,) array-like
Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
representing the weights of each discrete input measure

X_init : (k,d) array-like
Initialization of the support locations (on `k` atoms) of the barycenter
reg : float
Regularization term >0
b : (k,) array-like
Initialization of the weights of the barycenter (non-negatives, sum to 1)
weights : (N,) array-like
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)

numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
Max number of iterations when calculating the transport plans with Sinkhorn
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True

Returns
-------
X : (k,d) array-like
Support locations (on k atoms) of the barycenter

See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT solver
ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming

.. _references-free-support-barycenter:
References
----------
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.

.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

"""
nx = get_backend(*measures_locations, *measures_weights, X_init)

iter_count = 0

N = len(measures_locations)
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
b = nx.ones((k,), type_as=X_init) / k
if weights is None:
weights = nx.ones((N,), type_as=X_init) / N

X = X_init

log_dict = {}
displacement_square_norms = []

displacement_square_norm = stopThr + 1.

while (displacement_square_norm > stopThr and iter_count < numItermax):

T_sum = nx.zeros((k, d), type_as=X_init)

for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)

displacement_square_norm = nx.sum((T_sum - X) ** 2)
if log:
displacement_square_norms.append(displacement_square_norm)

X = T_sum

if verbose:
print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)

iter_count += 1

if log:
log_dict['displacement_square_norms'] = displacement_square_norms
return X, log_dict
else:
return X


def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic wasserstein barycenter in log-domain
Expand Down
26 changes: 26 additions & 0 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Kilian Fatras <kilian.fatras@irisa.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
# Eduardo Fernandes Montesuma <eduardo.fernandes-montesuma@universite-paris-saclay.fr>
#
# License: MIT License

Expand Down Expand Up @@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)


def test_free_support_sinkhorn_barycenter():
measures_locations = [
np.array([-1.]).reshape((1, 1)), # First dirac support
np.array([1.]).reshape((1, 1)) # Second dirac support
]

measures_weights = [
np.array([1.]), # First dirac sample weights
np.array([1.]) # Second dirac sample weights
]

# Barycenter initialization
X_init = np.array([-12.]).reshape((1, 1))

# Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter
bar_locations = np.array([0.]).reshape((1, 1))

# Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
# term to 1, but this should be, in general, fine-tuned to the problem.
X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)

# Verifies if calculated barycenter matches ground-truth
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)


@pytest.mark.parametrize("method, verbose, warn",
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
[True, False], [True, False]))
Expand Down