Skip to content

Commit b11f5a0

Browse files
[FEAT] Add EWCA in to ot.dr (#486)
* WIP: add EWCA in dr module * make ewca consistent with other methods from ot.dr * add EWCA paper in README * add test of ot.dr.ewca * change stopThr and maxiter_sink of ewca * improve test of ewca by checking estimation of first and last principle components * improve docstring of EWCA * rm double import of ot.dr * fix grassmann_distance function name * add EWCA to RELEASES.md * add example of EWCA * change title of EWCA example * improve EWCA example --------- Co-authored-by: Antoine Collas <22830806+antoinecollas@users.noreply.github.com> Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 80461cd commit b11f5a0

File tree

5 files changed

+401
-0
lines changed

5 files changed

+401
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
312312
[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
313313

314314
[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.
315+
316+
[52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). ArXiv.

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
- Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455)
1616
- Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455)
1717
- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)
18+
- Added Entropic Wasserstein Component Analysis (EWCA) in `ot.dr` (PR #486)
1819

1920
#### Closed issues
2021

examples/others/plot_EWCA.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
=======================================
4+
Entropic Wasserstein Component Analysis
5+
=======================================
6+
7+
This example illustrates the use of EWCA as proposed in [52].
8+
9+
10+
[52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023).
11+
Entropic Wasserstein Component Analysis.
12+
13+
"""
14+
15+
# Author: Antoine Collas <antoine.collas@inria.fr>
16+
#
17+
# License: MIT License
18+
# sphinx_gallery_thumbnail_number = 2
19+
20+
import numpy as np
21+
import matplotlib.pylab as pl
22+
from ot.dr import ewca
23+
from sklearn.datasets import make_blobs
24+
from matplotlib import ticker as mticker
25+
import matplotlib.patches as patches
26+
import matplotlib
27+
28+
##############################################################################
# Generate data
# -------------

# Two isotropic Gaussian blobs placed symmetrically around the origin.
# ``shuffle=False`` keeps the samples grouped by class, which the plots below
# rely on (first half of the rows = first class, second half = second class).
n_samples = 20
esp = 0.8
centers = esp * np.array([[1, 1], [-1, -1]])
cluster_std = 0.4

rng = np.random.RandomState(42)
blob_options = dict(
    n_samples=n_samples,
    n_features=2,
    centers=centers,
    cluster_std=cluster_std,
    shuffle=False,
    random_state=rng,
)
X, y = make_blobs(**blob_options)

# Center the data.
X = X - X.mean(0)
47+
48+
##############################################################################
# Plot data
# ---------

# Axis half-width and font size shared by all figures of this example.
x_y_lim = 2.5
fs = 15

fig = pl.figure(figsize=(4, 4))
cmap = matplotlib.colormaps.get_cmap("tab10")

# One scatter call per class so that each class gets its own legend entry.
half = n_samples // 2
class_slices = (
    (slice(None, half), "Class 1"),
    (slice(half, None), "Class 2"),
)
for rows, class_name in class_slices:
    pl.scatter(
        X[rows, 0],
        X[rows, 1],
        color=[cmap(label + 1) for label in y[rows]],
        alpha=0.4,
        label=class_name,
        zorder=30,
        s=50,
    )

pl.xlim(-x_y_lim, x_y_lim)
pl.xticks([])
pl.ylim(-x_y_lim, x_y_lim)
pl.yticks([])
pl.legend(fontsize=fs)
pl.title("Data", fontsize=fs)
pl.tight_layout()
81+
82+
83+
##############################################################################
# Compute EWCA
# ------------

# ``pi`` is the (n_samples, n_samples) transport plan and ``U`` the basis of
# the estimated subspace (one component per column).
pi, U = ewca(X, k=2, reg=0.5)


##############################################################################
# Plot data, first component, and projected data
# ----------------------------------------------

fig = pl.figure(figsize=(4, 4))

# Draw the direction of the first component as a dashed line through 0.
scale = 3
u = U[:, 0]
pl.plot(
    [scale * u[0], -scale * u[0]],
    [scale * u[1], -scale * u[1]],
    color="grey",
    linestyle="--",
    lw=3,
    alpha=0.3,
    label=r"$\mathbf{U}$",
)

# Projection of every sample onto the first component.
X1 = X @ u[:, None] @ u[:, None].T

# Link sample i to projected sample j when the (normalized) transported mass
# is large enough. The very last pair is always drawn because it is the one
# that carries the legend label for the transport plan.
pi_max = pi.max()
last = n_samples - 1
for i, j in np.ndindex(n_samples, n_samples):
    v = pi[i, j] / pi_max
    is_last_pair = (i, j) == (last, last)
    if not (v >= 0.15 or is_last_pair):
        continue
    pl.plot(
        [X[i, 0], X1[j, 0]],
        [X[i, 1], X1[j, 1]],
        alpha=v,
        linestyle="-",
        c="C0",
        label=r"$\pi_{ij}$" if is_last_pair else None,
    )

# Original samples (dots) and their projections (crosses), colored by class.
point_colors = [cmap(label + 1) for label in y]
pl.scatter(
    X[:, 0],
    X[:, 1],
    color=point_colors,
    alpha=0.4,
    label=r"$\mathbf{x}_i$",
    zorder=30,
    s=50,
)
pl.scatter(
    X1[:, 0],
    X1[:, 1],
    color=point_colors,
    alpha=0.9,
    s=50,
    marker="+",
    label=r"$\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_i$",
    zorder=30,
)
pl.title("Data and projections", fontsize=fs)
pl.xlim(-x_y_lim, x_y_lim)
pl.xticks([])
pl.ylim(-x_y_lim, x_y_lim)
pl.yticks([])
pl.legend(fontsize=fs, loc="upper left")
pl.tight_layout()
149+
150+
151+
##############################################################################
# Plot transport plan
# -------------------

fig = pl.figure(figsize=(5, 5))

# Show the plan as percentages: each row of ``n_samples * pi * 100`` sums to
# 100 when the marginals are uniform. The power norm makes small entries
# easier to see.
norm = matplotlib.colors.PowerNorm(0.5, vmin=0, vmax=100)
im = pl.imshow(n_samples * pi * 100, cmap=pl.cm.Blues, norm=norm, aspect="auto")

# Colorbar with "%" tick labels; the tick locations are frozen first so that
# the custom labels stay attached to them.
cb = fig.colorbar(im, orientation="vertical", shrink=0.8)
ticks_loc = cb.ax.get_yticks().tolist()
cb.ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
cb.ax.set_yticklabels([f"{int(tick)}%" for tick in cb.get_ticks()])
cb.ax.tick_params(labelsize=fs)

# Outline the diagonal block of each class in red.
sample_positions = np.arange(len(y))
for class_ in np.unique(y):
    members = sample_positions[y == class_]
    idx_min = members.min()
    idx_max = members.max()
    width = idx_max - idx_min + 1
    rect = patches.Rectangle(
        (idx_min - 0.5, idx_min - 0.5),
        width,
        width,
        linewidth=1,
        edgecolor="r",
        facecolor="none",
    )
    pl.gca().add_patch(rect)

pl.title("OT plan", fontsize=fs)
pl.ylabel(r"($\mathbf{x}_1, \cdots, \mathbf{x}_n$)")
x_label = (
    r"($\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_1, \cdots,"
    r"\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_n$)"
)
pl.xlabel(x_label)
pl.tight_layout()
pl.axis("scaled")

pl.show()

ot/dr.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,21 @@
1212
# Author: Remi Flamary <remi.flamary@unice.fr>
1313
# Minhui Huang <mhhuang@ucdavis.edu>
1414
# Jakub Zadrozny <jakub.r.zadrozny@gmail.com>
15+
# Antoine Collas <antoine.collas@inria.fr>
1516
#
1617
# License: MIT License
1718

1819
from scipy import linalg
1920
import autograd.numpy as np
21+
from sklearn.decomposition import PCA
2022

2123
import pymanopt
2224
import pymanopt.manifolds
2325
import pymanopt.optimizers
2426

27+
from .bregman import sinkhorn as sinkhorn_bregman
28+
from .utils import dist as dist_utils
29+
2530

2631
def dist(x1, x2):
2732
r""" Compute squared euclidean distance between samples (autograd)
@@ -376,3 +381,153 @@ def Vpi(X, Y, a, b, pi):
376381
iter = iter + 1
377382

378383
return pi, U
384+
385+
386+
def ewca(X, U0=None, reg=1, k=2, method='BCD', sinkhorn_method='sinkhorn', stopThr=1e-6, maxiter=100, maxiter_sink=1000, maxiter_MM=10, verbose=0):
    r"""
    Entropic Wasserstein Component Analysis :ref:`[52] <references-entropic-wasserstein-component_analysis>`.

    The function solves the following optimization problem:

    .. math::
        \mathbf{U} = \mathop{\arg \min}_\mathbf{U} \quad
        W(\mathbf{X}, \mathbf{U}\mathbf{U}^T \mathbf{X})

    where :

    - :math:`\mathbf{U}` is a matrix in the Stiefel(`d`, `k`) manifold
    - :math:`W` is the entropic regularized Wasserstein distance
    - :math:`\mathbf{X}` are samples

    The samples are centered internally before the optimization. The
    algorithm alternates between a Sinkhorn solve for the transport plan
    (with uniform marginals) and an update of :math:`\mathbf{U}`, and stops
    when the Grassmann distance between two successive subspaces falls below
    `stopThr` or after `maxiter` iterations.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples from measure :math:`\mu`.
    U0 : ndarray, shape (d, k), optional
        Initial starting point for projection. When not provided, the `k`
        first PCA components of `X` are used.
    reg : float, optional
        Regularization term >0 (entropic regularization).
    k : int, optional
        Subspace dimension.
    method : str, optional
        Either 'BCD' or 'MM' (Block Coordinate Descent or Majorization-Minimization).
        Prefer MM when d is large.
    sinkhorn_method : str
        Method used for the Sinkhorn solver, see :ref:`ot.bregman.sinkhorn` for more details.
    stopThr : float, optional
        Stop threshold on error (>0).
    maxiter : int, optional
        Maximum number of iterations of the BCD/MM (must be >= 1).
    maxiter_sink : int, optional
        Maximum number of iterations of the Sinkhorn solver.
    maxiter_MM : int, optional
        Maximum number of iterations of the MM (only used when method='MM').
    verbose : int, optional
        Print information along iterations.

    Returns
    -------
    pi : ndarray, shape (n, n)
        Optimal transportation matrix for the given parameters.
    U : ndarray, shape (d, k)
        Basis of the estimated subspace (point on the Stiefel manifold).

    Raises
    ------
    ValueError
        If `method` is neither 'BCD' nor 'MM', or if `maxiter` is smaller
        than 1.


    .. _references-entropic-wasserstein-component_analysis:
    References
    ----------
    .. [52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023).
            Entropic Wasserstein Component Analysis.
    """  # noqa
    # Validate cheap arguments up front, before any solver is run.
    if method not in ('BCD', 'MM'):
        raise ValueError(f"Unknown method '{method}', use 'BCD' or 'MM'.")
    if maxiter < 1:
        # With no iteration the transport plan would never be computed.
        raise ValueError(f"maxiter must be >= 1, got {maxiter}.")

    n, d = X.shape
    X = X - X.mean(0)

    lambda_scm = None
    if U0 is None:
        pca_fitted = PCA(n_components=k).fit(X)
        U = pca_fitted.components_.T
        if method == 'MM':
            lambda_scm = pca_fitted.explained_variance_[0]
    else:
        U = U0
        if method == 'MM':
            # Largest eigenvalue of the sample covariance of the centered
            # data, i.e. the quantity PCA exposes as explained_variance_[0].
            # It is needed by the MM update even when U0 is user-provided.
            eig_scm, _ = np.linalg.eigh(X.T @ X / max(n - 1, 1))
            lambda_scm = eig_scm[-1]

    # marginals
    u0 = (1. / n) * np.ones(n)

    # print iterations
    if verbose > 0:
        print('{:4s}|{:13s}|{:12s}|{:12s}'.format('It.', 'Loss', 'Crit.', 'Thres.') + '\n' + '-' * 40)

    def compute_loss(M, pi, reg):
        # transport cost plus entropic term
        return np.sum(M * pi) + reg * np.sum(pi * (np.log(pi) - 1))

    def grassmann_distance(U1, U2):
        # norm of the principal angles between the two subspaces
        proj = U1.T @ U2
        _, s, _ = np.linalg.svd(proj)
        # guard arccos against singular values slightly above 1
        s[s > 1] = 1
        s = np.arccos(s)
        return np.linalg.norm(s)

    # loop
    it = 0
    crit = np.inf
    sinkhorn_warmstart = None

    while (it < maxiter) and (crit > stopThr):
        U_old = U

        # Solve transport between the samples and their projections
        M = dist_utils(X, (X @ U) @ U.T)
        pi, log_sinkhorn = sinkhorn_bregman(
            u0, u0, M, reg,
            numItermax=maxiter_sink,
            method=sinkhorn_method, warmstart=sinkhorn_warmstart,
            warn=False, log=True
        )
        # reuse the solver state for the next Sinkhorn call when available
        sinkhorn_warmstart = log_sinkhorn.get('warmstart', sinkhorn_warmstart)

        # The loss is only reported, never used by the algorithm: skip it
        # unless it is going to be printed.
        if verbose > 0:
            if (pi >= 1e-300).all():
                loss = compute_loss(M, pi, reg)
            else:
                # log(pi) would not be finite
                loss = np.inf

        # Solve PCA
        pi_sym = (pi + pi.T) / 2

        if method == 'BCD':
            # block coordinate descent: eigenvectors of S associated with
            # its k largest eigenvalues (eigh sorts them in ascending order)
            S = X.T @ (2 * pi_sym - (1. / n) * np.eye(n)) @ X
            _, U = np.linalg.eigh(S)
            U = U[:, ::-1][:, :k]

        else:
            # majorization-minimization
            eig, _ = np.linalg.eigh(pi_sym)
            lambda_pi = eig[0]  # smallest eigenvalue of pi_sym
            alpha = 1 - 2 * n * lambda_pi  # constant over the MM iterations

            for _ in range(maxiter_MM):
                X_proj = X @ U
                X_T_X_proj = X.T @ X_proj

                R = (1 / n) * X_T_X_proj
                if alpha > 0:
                    R = alpha * (R - lambda_scm * U)
                else:
                    R = alpha * R

                R = R - (2 * X.T @ (pi_sym @ X_proj)) + (2 * lambda_pi * X_T_X_proj)
                # orthonormalize the update
                U, _ = np.linalg.qr(R)

        # stop or not
        it += 1
        crit = grassmann_distance(U_old, U)

        # print
        if verbose > 0:
            print('{:4d}|{:8e}|{:8e}|{:8e}'.format(it, loss, crit, stopThr))

    return pi, U

0 commit comments

Comments
 (0)