Skip to content

[FEAT] Add EWCA in to ot.dr #486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).

[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.

[52] Collas, A., Vayer, T., Flamary, F., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
- Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455)
- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455)
- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)
- Add Entropic Wasserstein Component Analysis (ECWA) in ot.dr (PR #486)

#### Closed issues

Expand Down
187 changes: 187 additions & 0 deletions examples/others/plot_EWCA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
"""
=======================================
Entropic Wasserstein Component Analysis
=======================================

This example illustrates the use of EWCA as proposed in [52].


[52] Collas, A., Vayer, T., Flamary, F., & Breloy, A. (2023).
Entropic Wasserstein Component Analysis.

"""

# Author: Antoine Collas <antoine.collas@inria.fr>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pylab as pl
from ot.dr import ewca
from sklearn.datasets import make_blobs
from matplotlib import ticker as mticker
import matplotlib.patches as patches
import matplotlib

##############################################################################
# Generate data
# -------------

n_samples = 20
esp = 0.8
centers = np.array([[esp, esp], [-esp, -esp]])
cluster_std = 0.4

rng = np.random.RandomState(42)
X, y = make_blobs(
n_samples=n_samples,
n_features=2,
centers=centers,
cluster_std=cluster_std,
shuffle=False,
random_state=rng,
)
X = X - X.mean(0)

##############################################################################
# Plot data
# -------------

fig = pl.figure(figsize=(4, 4))
cmap = matplotlib.colormaps.get_cmap("tab10")
pl.scatter(
X[: n_samples // 2, 0],
X[: n_samples // 2, 1],
color=[cmap(y[i] + 1) for i in range(n_samples // 2)],
alpha=0.4,
label="Class 1",
zorder=30,
s=50,
)
pl.scatter(
X[n_samples // 2:, 0],
X[n_samples // 2:, 1],
color=[cmap(y[i] + 1) for i in range(n_samples // 2, n_samples)],
alpha=0.4,
label="Class 2",
zorder=30,
s=50,
)
x_y_lim = 2.5
fs = 15
pl.xlim(-x_y_lim, x_y_lim)
pl.xticks([])
pl.ylim(-x_y_lim, x_y_lim)
pl.yticks([])
pl.legend(fontsize=fs)
pl.title("Data", fontsize=fs)
pl.tight_layout()


##############################################################################
# Compute EWCA
# -------------

pi, U = ewca(X, k=2, reg=0.5)


##############################################################################
# Plot data, first component, and projected data
# -------------

fig = pl.figure(figsize=(4, 4))

scale = 3
u = U[:, 0]
pl.plot(
[scale * u[0], -scale * u[0]],
[scale * u[1], -scale * u[1]],
color="grey",
linestyle="--",
lw=3,
alpha=0.3,
label=r"$\mathbf{U}$",
)
X1 = X @ u[:, None] @ u[:, None].T

for i in range(n_samples):
for j in range(n_samples):
v = pi[i, j] / pi.max()
if v >= 0.15 or (i, j) == (n_samples - 1, n_samples - 1):
pl.plot(
[X[i, 0], X1[j, 0]],
[X[i, 1], X1[j, 1]],
alpha=v,
linestyle="-",
c="C0",
label=r"$\pi_{ij}$"
if (i, j) == (n_samples - 1, n_samples - 1)
else None,
)
pl.scatter(
X[:, 0],
X[:, 1],
color=[cmap(y[i] + 1) for i in range(n_samples)],
alpha=0.4,
label=r"$\mathbf{x}_i$",
zorder=30,
s=50,
)
pl.scatter(
X1[:, 0],
X1[:, 1],
color=[cmap(y[i] + 1) for i in range(n_samples)],
alpha=0.9,
s=50,
marker="+",
label=r"$\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_i$",
zorder=30,
)
pl.title("Data and projections", fontsize=fs)
pl.xlim(-x_y_lim, x_y_lim)
pl.xticks([])
pl.ylim(-x_y_lim, x_y_lim)
pl.yticks([])
pl.legend(fontsize=fs, loc="upper left")
pl.tight_layout()


##############################################################################
# Plot transport plan
# -------------

fig = pl.figure(figsize=(5, 5))

norm = matplotlib.colors.PowerNorm(0.5, vmin=0, vmax=100)
im = pl.imshow(n_samples * pi * 100, cmap=pl.cm.Blues, norm=norm, aspect="auto")
cb = fig.colorbar(im, orientation="vertical", shrink=0.8)
ticks_loc = cb.ax.get_yticks().tolist()
cb.ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
cb.ax.set_yticklabels([f"{int(i)}%" for i in cb.get_ticks()])
cb.ax.tick_params(labelsize=fs)
for i, class_ in enumerate(np.sort(np.unique(y))):
indices = y == class_
idx_min = np.min(np.arange(len(y))[indices])
idx_max = np.max(np.arange(len(y))[indices])
width = idx_max - idx_min + 1
rect = patches.Rectangle(
(idx_min - 0.5, idx_min - 0.5),
width,
width,
linewidth=1,
edgecolor="r",
facecolor="none",
)
pl.gca().add_patch(rect)

pl.title("OT plan", fontsize=fs)
pl.ylabel(r"($\mathbf{x}_1, \cdots, \mathbf{x}_n$)")
x_label = r"($\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_1, \cdots,"
x_label += r"\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_n$)"
pl.xlabel(x_label)
pl.tight_layout()
pl.axis("scaled")

pl.show()
155 changes: 155 additions & 0 deletions ot/dr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,21 @@
# Author: Remi Flamary <remi.flamary@unice.fr>
# Minhui Huang <mhhuang@ucdavis.edu>
# Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
# Antoine Collas <antoine.collas@inria.fr>
#
# License: MIT License

from scipy import linalg
import autograd.numpy as np
from sklearn.decomposition import PCA

import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers

from .bregman import sinkhorn as sinkhorn_bregman
from .utils import dist as dist_utils


def dist(x1, x2):
r""" Compute squared euclidean distance between samples (autograd)
Expand Down Expand Up @@ -376,3 +381,153 @@ def Vpi(X, Y, a, b, pi):
iter = iter + 1

return pi, U


def ewca(X, U0=None, reg=1, k=2, method='BCD', sinkhorn_method='sinkhorn', stopThr=1e-6, maxiter=100, maxiter_sink=1000, maxiter_MM=10, verbose=0):
r"""
Entropic Wasserstein Component Analysis :ref:`[52] <references-entropic-wasserstein-component_analysis>`.

The function solves the following optimization problem:

.. math::
\mathbf{U} = \mathop{\arg \min}_\mathbf{U} \quad
W(\mathbf{X}, \mathbf{U}\mathbf{U}^T \mathbf{X})

where :

- :math:`\mathbf{U}` is a matrix in the Stiefel(`p`, `d`) manifold
- :math:`W` is entropic regularized Wasserstein distances
- :math:`\mathbf{X}` are samples

Parameters
----------
X : ndarray, shape (n, d)
Samples from measure :math:`\mu`.
U0 : ndarray, shape (d, k), optional
Initial starting point for projection.
reg : float, optional
Regularization term >0 (entropic regularization).
k : int, optional
Subspace dimension.
method : str, optional
Eather 'BCD' or 'MM' (Block Coordinate Descent or Majorization-Minimization).
Prefer MM when d is large.
sinkhorn_method : str
Method used for the Sinkhorn solver, see :ref:`ot.bregman.sinkhorn` for more details.
stopThr : float, optional
Stop threshold on error (>0).
maxiter : int, optional
Maximum number of iterations of the BCD/MM.
maxiter_sink : int, optional
Maximum number of iterations of the Sinkhorn solver.
maxiter_MM : int, optional
Maximum number of iterations of the MM (only used when method='MM').
verbose : int, optional
Print information along iterations.

Returns
-------
pi : ndarray, shape (n, n)
Optimal transportation matrix for the given parameters.
U : ndarray, shape (d, k)
Matrix Stiefel manifold.


.. _references-entropic-wasserstein-component_analysis:
References
----------
.. [52] Collas, A., Vayer, T., Flamary, F., & Breloy, A. (2023).
Entropic Wasserstein Component Analysis.
""" # noqa
n, d = X.shape
X = X - X.mean(0)

if U0 is None:
pca_fitted = PCA(n_components=k).fit(X)
U = pca_fitted.components_.T
if method == 'MM':
lambda_scm = pca_fitted.explained_variance_[0]
else:
U = U0

# marginals
u0 = (1. / n) * np.ones(n)

# print iterations
if verbose > 0:
print('{:4s}|{:13s}|{:12s}|{:12s}'.format('It.', 'Loss', 'Crit.', 'Thres.') + '\n' + '-' * 40)

def compute_loss(M, pi, reg):
return np.sum(M * pi) + reg * np.sum(pi * (np.log(pi) - 1))

def grassmann_distance(U1, U2):
proj = U1.T @ U2
_, s, _ = np.linalg.svd(proj)
s[s > 1] = 1
s = np.arccos(s)
return np.linalg.norm(s)

# loop
it = 0
crit = np.inf
sinkhorn_warmstart = None

while (it < maxiter) and (crit > stopThr):
U_old = U

# Solve transport
M = dist_utils(X, (X @ U) @ U.T)
pi, log_sinkhorn = sinkhorn_bregman(
u0, u0, M, reg,
numItermax=maxiter_sink,
method=sinkhorn_method, warmstart=sinkhorn_warmstart,
warn=False, log=True
)
key_warmstart = 'warmstart'
if key_warmstart in log_sinkhorn:
sinkhorn_warmstart = log_sinkhorn[key_warmstart]
if (pi >= 1e-300).all():
loss = compute_loss(M, pi, reg)
else:
loss = np.inf

# Solve PCA
pi_sym = (pi + pi.T) / 2

if method == 'BCD':
# block coordinate descent
S = X.T @ (2 * pi_sym - (1. / n) * np.eye(n)) @ X
_, U = np.linalg.eigh(S)
U = U[:, ::-1][:, :k]

elif method == 'MM':
# majorization-minimization
eig, _ = np.linalg.eigh(pi_sym)
lambda_pi = eig[0]

for _ in range(maxiter_MM):
X_proj = X @ U
X_T_X_proj = X.T @ X_proj

R = (1 / n) * X_T_X_proj
alpha = 1 - 2 * n * lambda_pi
if alpha > 0:
R = alpha * (R - lambda_scm * U)
else:
R = alpha * R

R = R - (2 * X.T @ (pi_sym @ X_proj)) + (2 * lambda_pi * X_T_X_proj)
U, _ = np.linalg.qr(R)

else:
raise ValueError(f"Unknown method '{method}', use 'BCD' or 'MM'.")

# stop or not
it += 1
crit = grassmann_distance(U_old, U)

# print
if verbose > 0:
print('{:4d}|{:8e}|{:8e}|{:8e}'.format(it, loss, crit, stopThr))

return pi, U
Loading