diff --git a/README.md b/README.md
index 25c04011f..00e2bb989 100644
--- a/README.md
+++ b/README.md
@@ -312,3 +312,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
 [51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019). [Gromov-wasserstein learning for graph matching and node embedding](http://proceedings.mlr.press/v97/xu19b.html). In International Conference on Machine Learning (ICML), 2019.
+
+[52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023). [Entropic Wasserstein Component Analysis](https://arxiv.org/abs/2303.05119). arXiv preprint arXiv:2303.05119.
diff --git a/RELEASES.md b/RELEASES.md
index 595fecfb5..bd74da059 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -15,6 +15,7 @@
 - Added features `warmstartT` and `kwargs` to all CG and entropic (F)GW barycenter solvers (PR #455)
 - Added entropic semi-relaxed (Fused) Gromov-Wasserstein solvers in `ot.gromov` + examples (PR #455)
 - Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)
+- Add Entropic Wasserstein Component Analysis (EWCA) in `ot.dr` (PR #486)
 
 #### Closed issues
diff --git a/examples/others/plot_EWCA.py b/examples/others/plot_EWCA.py
new file mode 100644
index 000000000..fb9bd713f
--- /dev/null
+++ b/examples/others/plot_EWCA.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+"""
+=======================================
+Entropic Wasserstein Component Analysis
+=======================================
+
+This example illustrates the use of EWCA as proposed in [52].
+
+[52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023).
+Entropic Wasserstein Component Analysis.
+https://arxiv.org/abs/2303.05119
+ +""" + +# Author: Antoine Collas +# +# License: MIT License +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +from ot.dr import ewca +from sklearn.datasets import make_blobs +from matplotlib import ticker as mticker +import matplotlib.patches as patches +import matplotlib + +############################################################################## +# Generate data +# ------------- + +n_samples = 20 +esp = 0.8 +centers = np.array([[esp, esp], [-esp, -esp]]) +cluster_std = 0.4 + +rng = np.random.RandomState(42) +X, y = make_blobs( + n_samples=n_samples, + n_features=2, + centers=centers, + cluster_std=cluster_std, + shuffle=False, + random_state=rng, +) +X = X - X.mean(0) + +############################################################################## +# Plot data +# ------------- + +fig = pl.figure(figsize=(4, 4)) +cmap = matplotlib.colormaps.get_cmap("tab10") +pl.scatter( + X[: n_samples // 2, 0], + X[: n_samples // 2, 1], + color=[cmap(y[i] + 1) for i in range(n_samples // 2)], + alpha=0.4, + label="Class 1", + zorder=30, + s=50, +) +pl.scatter( + X[n_samples // 2:, 0], + X[n_samples // 2:, 1], + color=[cmap(y[i] + 1) for i in range(n_samples // 2, n_samples)], + alpha=0.4, + label="Class 2", + zorder=30, + s=50, +) +x_y_lim = 2.5 +fs = 15 +pl.xlim(-x_y_lim, x_y_lim) +pl.xticks([]) +pl.ylim(-x_y_lim, x_y_lim) +pl.yticks([]) +pl.legend(fontsize=fs) +pl.title("Data", fontsize=fs) +pl.tight_layout() + + +############################################################################## +# Compute EWCA +# ------------- + +pi, U = ewca(X, k=2, reg=0.5) + + +############################################################################## +# Plot data, first component, and projected data +# ------------- + +fig = pl.figure(figsize=(4, 4)) + +scale = 3 +u = U[:, 0] +pl.plot( + [scale * u[0], -scale * u[0]], + [scale * u[1], -scale * u[1]], + color="grey", + linestyle="--", + lw=3, + alpha=0.3, + label=r"$\mathbf{U}$", +) +X1 = X @ u[:, None] @ u[:, None].T + +for i in range(n_samples): + for j in range(n_samples): + v = pi[i, j] / pi.max() + if v >= 0.15 or (i, j) == (n_samples - 1, n_samples - 1): + pl.plot( + [X[i, 0], X1[j, 0]], + [X[i, 1], X1[j, 1]], + alpha=v, + linestyle="-", + c="C0", + label=r"$\pi_{ij}$" + if (i, j) == (n_samples - 1, n_samples - 1) + else None, + ) +pl.scatter( + X[:, 0], + X[:, 1], + color=[cmap(y[i] + 1) for i in range(n_samples)], + alpha=0.4, + label=r"$\mathbf{x}_i$", + zorder=30, + s=50, +) +pl.scatter( + X1[:, 0], + X1[:, 1], + color=[cmap(y[i] + 1) for i in range(n_samples)], + alpha=0.9, + s=50, + marker="+", + label=r"$\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_i$", + zorder=30, +) +pl.title("Data and projections", fontsize=fs) +pl.xlim(-x_y_lim, x_y_lim) +pl.xticks([]) +pl.ylim(-x_y_lim, x_y_lim) +pl.yticks([]) +pl.legend(fontsize=fs, loc="upper left") +pl.tight_layout() + + +############################################################################## +# Plot transport plan +# ------------- + +fig = pl.figure(figsize=(5, 5)) + +norm = matplotlib.colors.PowerNorm(0.5, vmin=0, vmax=100) +im = pl.imshow(n_samples * pi * 100, cmap=pl.cm.Blues, norm=norm, aspect="auto") +cb = fig.colorbar(im, orientation="vertical", shrink=0.8) +ticks_loc = cb.ax.get_yticks().tolist() +cb.ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) +cb.ax.set_yticklabels([f"{int(i)}%" for i in cb.get_ticks()]) +cb.ax.tick_params(labelsize=fs) +for i, class_ in enumerate(np.sort(np.unique(y))): + indices = y == class_ + idx_min = 
np.min(np.arange(len(y))[indices])
+    idx_max = np.max(np.arange(len(y))[indices])
+    width = idx_max - idx_min + 1
+    rect = patches.Rectangle(
+        (idx_min - 0.5, idx_min - 0.5),
+        width,
+        width,
+        linewidth=1,
+        edgecolor="r",
+        facecolor="none",
+    )
+    pl.gca().add_patch(rect)
+
+pl.title("OT plan", fontsize=fs)
+pl.ylabel(r"($\mathbf{x}_1, \cdots, \mathbf{x}_n$)")
+x_label = r"($\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_1, \cdots,"
+x_label += r"\mathbf{U}\mathbf{U}^{\top}\mathbf{x}_n$)"
+pl.xlabel(x_label)
+pl.tight_layout()
+pl.axis("scaled")
+
+pl.show()
diff --git a/ot/dr.py b/ot/dr.py
index 47c8733a2..cb5768fec 100644
--- a/ot/dr.py
+++ b/ot/dr.py
@@ -12,16 +12,21 @@
 # Author: Remi Flamary
 #         Minhui Huang
 #         Jakub Zadrozny
+#         Antoine Collas
 #
 # License: MIT License
 
 from scipy import linalg
 import autograd.numpy as np
+from sklearn.decomposition import PCA
 
 import pymanopt
 import pymanopt.manifolds
 import pymanopt.optimizers
 
+from .bregman import sinkhorn as sinkhorn_bregman
+from .utils import dist as dist_utils
+
 
 def dist(x1, x2):
     r""" Compute squared euclidean distance between samples (autograd)
@@ -376,3 +381,159 @@ def Vpi(X, Y, a, b, pi):
         iter = iter + 1
 
     return pi, U
+
+
+def ewca(X, U0=None, reg=1, k=2, method='BCD', sinkhorn_method='sinkhorn', stopThr=1e-6, maxiter=100, maxiter_sink=1000, maxiter_MM=10, verbose=0):
+    r"""
+    Entropic Wasserstein Component Analysis :ref:`[52] <references-entropic-wasserstein-component_analysis>`.
+
+    The function solves the following optimization problem:
+
+    .. math::
+        \mathbf{U} = \mathop{\arg \min}_\mathbf{U} \quad
+        W(\mathbf{X}, \mathbf{U}\mathbf{U}^T \mathbf{X})
+
+    where:
+
+    - :math:`\mathbf{U}` is a matrix on the Stiefel(`d`, `k`) manifold
+    - :math:`W` is the entropic regularized Wasserstein distance
+    - :math:`\mathbf{X}` are the samples
+
+    Parameters
+    ----------
+    X : ndarray, shape (n, d)
+        Samples from measure :math:`\mu`.
+    U0 : ndarray, shape (d, k), optional
+        Initial starting point for projection.
+    reg : float, optional
+        Regularization term >0 (entropic regularization).
+    k : int, optional
+        Subspace dimension.
+    method : str, optional
+        Either 'BCD' or 'MM' (Block Coordinate Descent or Majorization-Minimization).
+        Prefer MM when d is large.
+    sinkhorn_method : str
+        Method used for the Sinkhorn solver, see :ref:`ot.bregman.sinkhorn` for more details.
+    stopThr : float, optional
+        Stop threshold on error (>0).
+    maxiter : int, optional
+        Maximum number of iterations of the BCD/MM.
+    maxiter_sink : int, optional
+        Maximum number of iterations of the Sinkhorn solver.
+    maxiter_MM : int, optional
+        Maximum number of iterations of the MM (only used when method='MM').
+    verbose : int, optional
+        Print information along iterations.
+
+    Returns
+    -------
+    pi : ndarray, shape (n, n)
+        Optimal transportation matrix for the given parameters.
+    U : ndarray, shape (d, k)
+        Matrix on the Stiefel manifold.
+
+
+    .. _references-entropic-wasserstein-component_analysis:
+    References
+    ----------
+    .. [52] Collas, A., Vayer, T., Flamary, R., & Breloy, A. (2023).
+        Entropic Wasserstein Component Analysis.
+    """  # noqa
+    n, d = X.shape
+    X = X - X.mean(0)
+
+    pca_fitted = None
+    if U0 is None:
+        pca_fitted = PCA(n_components=k).fit(X)
+        U = pca_fitted.components_.T
+    else:
+        U = U0
+
+    if method == 'MM':
+        # the MM majorizer needs the largest eigenvalue of the sample
+        # covariance matrix, also when a warm start U0 is provided
+        if pca_fitted is None:
+            pca_fitted = PCA(n_components=1).fit(X)
+        lambda_scm = pca_fitted.explained_variance_[0]
+
+    # marginals
+    u0 = (1. 
/ n) * np.ones(n) + + # print iterations + if verbose > 0: + print('{:4s}|{:13s}|{:12s}|{:12s}'.format('It.', 'Loss', 'Crit.', 'Thres.') + '\n' + '-' * 40) + + def compute_loss(M, pi, reg): + return np.sum(M * pi) + reg * np.sum(pi * (np.log(pi) - 1)) + + def grassmann_distance(U1, U2): + proj = U1.T @ U2 + _, s, _ = np.linalg.svd(proj) + s[s > 1] = 1 + s = np.arccos(s) + return np.linalg.norm(s) + + # loop + it = 0 + crit = np.inf + sinkhorn_warmstart = None + + while (it < maxiter) and (crit > stopThr): + U_old = U + + # Solve transport + M = dist_utils(X, (X @ U) @ U.T) + pi, log_sinkhorn = sinkhorn_bregman( + u0, u0, M, reg, + numItermax=maxiter_sink, + method=sinkhorn_method, warmstart=sinkhorn_warmstart, + warn=False, log=True + ) + key_warmstart = 'warmstart' + if key_warmstart in log_sinkhorn: + sinkhorn_warmstart = log_sinkhorn[key_warmstart] + if (pi >= 1e-300).all(): + loss = compute_loss(M, pi, reg) + else: + loss = np.inf + + # Solve PCA + pi_sym = (pi + pi.T) / 2 + + if method == 'BCD': + # block coordinate descent + S = X.T @ (2 * pi_sym - (1. / n) * np.eye(n)) @ X + _, U = np.linalg.eigh(S) + U = U[:, ::-1][:, :k] + + elif method == 'MM': + # majorization-minimization + eig, _ = np.linalg.eigh(pi_sym) + lambda_pi = eig[0] + + for _ in range(maxiter_MM): + X_proj = X @ U + X_T_X_proj = X.T @ X_proj + + R = (1 / n) * X_T_X_proj + alpha = 1 - 2 * n * lambda_pi + if alpha > 0: + R = alpha * (R - lambda_scm * U) + else: + R = alpha * R + + R = R - (2 * X.T @ (pi_sym @ X_proj)) + (2 * lambda_pi * X_T_X_proj) + U, _ = np.linalg.qr(R) + + else: + raise ValueError(f"Unknown method '{method}', use 'BCD' or 'MM'.") + + # stop or not + it += 1 + crit = grassmann_distance(U_old, U) + + # print + if verbose > 0: + print('{:4d}|{:8e}|{:8e}|{:8e}'.format(it, loss, crit, stopThr)) + + return pi, U diff --git a/test/test_dr.py b/test/test_dr.py index 6d7fc9aa6..4f0d937ef 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -2,6 +2,7 @@ # Author: Remi Flamary # Minhui Huang +# Antoine Collas # # License: MIT License @@ -141,3 +142,58 @@ def fragmented_hypercube(n, d, dim): U0, _ = np.linalg.qr(U0) pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1) + + +@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") +def test_ewca(): + + d = 5 + n_samples = 50 + k = 3 + np.random.seed(0) + + # generate gaussian dataset + A = np.random.normal(size=(d, d)) + Q, _ = np.linalg.qr(A) + D = np.random.normal(size=d) + D = (D / np.linalg.norm(D)) ** 4 + cov = Q @ np.diag(D) @ Q.T + X = np.random.multivariate_normal(np.zeros(d), cov, size=n_samples) + X = X - X.mean(0, keepdims=True) + assert X.shape == (n_samples, d) + + # compute first 3 components with BCD + pi, U = ot.dr.ewca(X, reg=0.01, method='BCD', k=k, verbose=1, sinkhorn_method='sinkhorn_log') + assert pi.shape == (n_samples, n_samples) + assert (pi >= 0).all() + assert np.allclose(pi.sum(0), 1 / n_samples, atol=1e-3) + assert np.allclose(pi.sum(1), 1 / n_samples, atol=1e-3) + assert U.shape == (d, k) + assert np.allclose(U.T @ U, np.eye(k), atol=1e-3) + + # test that U contains the principal components + U_first_eigvec = np.linalg.svd(X.T, full_matrices=False)[0][:, :k] + _, cos, _ = np.linalg.svd(U.T @ U_first_eigvec, full_matrices=False) + assert np.allclose(cos, np.ones(k), atol=1e-3) + + # compute first 3 components with MM + pi, U = ot.dr.ewca(X, reg=0.01, method='MM', k=k, verbose=1, sinkhorn_method='sinkhorn_log') + assert pi.shape == (n_samples, n_samples) + assert 
(pi >= 0).all() + assert np.allclose(pi.sum(0), 1 / n_samples, atol=1e-3) + assert np.allclose(pi.sum(1), 1 / n_samples, atol=1e-3) + assert U.shape == (d, k) + assert np.allclose(U.T @ U, np.eye(k), atol=1e-3) + + # test that U contains the principal components + U_first_eigvec = np.linalg.svd(X.T, full_matrices=False)[0][:, :k] + _, cos, _ = np.linalg.svd(U.T @ U_first_eigvec, full_matrices=False) + assert np.allclose(cos, np.ones(k), atol=1e-3) + + # compute last 3 components + pi, U = ot.dr.ewca(X, reg=100000, method='MM', k=k, verbose=1, sinkhorn_method='sinkhorn_log') + + # test that U contains the last principal components + U_last_eigvec = np.linalg.svd(X.T, full_matrices=False)[0][:, -k:] + _, cos, _ = np.linalg.svd(U.T @ U_last_eigvec, full_matrices=False) + assert np.allclose(cos, np.ones(k), atol=1e-3)
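
For quick manual testing of this PR, here is a minimal usage sketch of the new solver. It is not part of the diff: it assumes the branch is installed and that `autograd` and `pymanopt` (required by `ot.dr`) are available, and the toy data is purely illustrative.

```python
# Minimal sketch (not from the PR): run the new EWCA solver on toy data.
import numpy as np
from ot.dr import ewca  # ot.dr is an optional module, needs autograd + pymanopt

rng = np.random.RandomState(0)
X = rng.normal(size=(50, 5))  # 50 samples in dimension 5
X = X - X.mean(0)  # center the data, as in the tests above

# BCD with the log-domain Sinkhorn solver, mirroring the test configuration
pi, U = ewca(X, k=2, reg=0.01, method='BCD', sinkhorn_method='sinkhorn_log')

print(pi.shape)  # (50, 50): OT plan between the data and its projection
print(U.shape)   # (5, 2): orthonormal basis of the subspace
print(np.allclose(U.T @ U, np.eye(2)))  # True, U is on the Stiefel manifold

X_proj = X @ U @ U.T  # samples projected onto the learned subspace
```

A small `reg` keeps the plan close to the unregularized optimum, while a very large `reg` pushes it toward the independent coupling; the last test above uses `reg=100000` to exploit this and recover the smallest principal components instead of the largest ones.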