Skip to content

Commit 7af8c21

Browse files
haoran010haoran010rflamary
authored
[MRG] Regularization path for l2 UOT (#274)
* add reg path * debug examples and verify pep8 * pep8 and move the reg path examples in unbalanced folder Co-authored-by: haoran010 <haoran.wu@insa-rennes.fr> Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent d50d814 commit 7af8c21

File tree

4 files changed

+1028
-1
lines changed

4 files changed

+1028
-1
lines changed
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
================================================================
4+
Regularization path of l2-penalized unbalanced optimal transport
5+
================================================================
6+
This example illustrate the regularization path for 2D unbalanced
7+
optimal transport. We present here both the fully relaxed case
8+
and the semi-relaxed case.
9+
10+
[Chapel et al., 2021] Chapel, L., Flamary, R., Wu, H., Févotte, C.,
11+
and Gasso, G. (2021). Unbalanced optimal transport through non-negative
12+
penalized linear regression.
13+
"""
14+
15+
# Author: Haoran Wu <haoran.wu@univ-ubs.fr>
16+
# License: MIT License
17+
18+
19+
import numpy as np
20+
import matplotlib.pylab as pl
21+
import ot
22+
23+
##############################################################################
24+
# Generate data
25+
# -------------
26+
27+
#%% parameters and data generation
28+
29+
n = 50 # nb samples
30+
31+
mu_s = np.array([-1, -1])
32+
cov_s = np.array([[1, 0], [0, 1]])
33+
34+
mu_t = np.array([4, 4])
35+
cov_t = np.array([[1, -.8], [-.8, 1]])
36+
37+
np.random.seed(0)
38+
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
39+
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
40+
41+
a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples
42+
43+
# loss matrix
44+
M = ot.dist(xs, xt)
45+
M /= M.max()
46+
47+
##############################################################################
48+
# Plot data
49+
# ---------
50+
51+
#%% plot 2 distribution samples
52+
53+
pl.figure(1)
54+
pl.scatter(xs[:, 0], xs[:, 1], c='C0', label='Source')
55+
pl.scatter(xt[:, 0], xt[:, 1], c='C1', label='Target')
56+
pl.legend(loc=2)
57+
pl.title('Source and target distributions')
58+
pl.show()
59+
60+
##############################################################################
61+
# Compute semi-relaxed and fully relaxed regularization paths
62+
# -----------
63+
64+
#%%
65+
final_gamma = 1e-8
66+
t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
67+
semi_relaxed=False)
68+
t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
69+
semi_relaxed=True)
70+
71+
72+
##############################################################################
73+
# Plot the regularization path
74+
# ----------------
75+
76+
#%% fully relaxed l2-penalized UOT
77+
78+
pl.figure(2)
79+
selected_gamma = [2e-1, 1e-1, 5e-2, 1e-3]
80+
for p in range(4):
81+
tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list,
82+
t_list)
83+
P = tp.reshape((n, n))
84+
pl.subplot(2, 2, p + 1)
85+
if P.sum() > 0:
86+
P = P / P.max()
87+
for i in range(n):
88+
for j in range(n):
89+
if P[i, j] > 0:
90+
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
91+
alpha=P[i, j] * 0.3)
92+
pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
93+
pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=0.2)
94+
pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * (1 + p) * 2,
95+
label='Re-weighted source', alpha=1)
96+
pl.scatter(xt[:, 0], xt[:, 1], c='C1', s=P.sum(0).ravel() * (1 + p) * 2,
97+
label='Re-weighted target', alpha=1)
98+
pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
99+
pl.title(r'$\ell_2$ UOT $\gamma$={}'.format(selected_gamma[p]),
100+
fontsize=11)
101+
if p < 2:
102+
pl.xticks(())
103+
pl.show()
104+
105+
106+
##############################################################################
107+
# Plot the semi-relaxed regularization path
108+
# -------------------
109+
110+
#%% semi-relaxed l2-penalized UOT
111+
112+
pl.figure(3)
113+
selected_gamma = [10, 1, 1e-1, 1e-2]
114+
for p in range(4):
115+
tp = ot.regpath.compute_transport_plan(selected_gamma[p], g_list2,
116+
t_list2)
117+
P = tp.reshape((n, n))
118+
pl.subplot(2, 2, p + 1)
119+
if P.sum() > 0:
120+
P = P / P.max()
121+
for i in range(n):
122+
for j in range(n):
123+
if P[i, j] > 0:
124+
pl.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]], color='C2',
125+
alpha=P[i, j] * 0.3)
126+
pl.scatter(xs[:, 0], xs[:, 1], c='C0', alpha=0.2)
127+
pl.scatter(xt[:, 0], xt[:, 1], c='C1', alpha=1, label='Target marginal')
128+
pl.scatter(xs[:, 0], xs[:, 1], c='C0', s=P.sum(1).ravel() * 2 * (1 + p),
129+
label='Source marginal', alpha=1)
130+
pl.plot([], [], color='C2', alpha=0.8, label='OT plan')
131+
pl.title(r'Semi-relaxed $l_2$ UOT $\gamma$={}'.format(selected_gamma[p]),
132+
fontsize=11)
133+
if p < 2:
134+
pl.xticks(())
135+
pl.show()

ot/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from . import unbalanced
3535
from . import partial
3636
from . import backend
37+
from . import regpath
3738

3839
# OT functions
3940
from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d
@@ -54,4 +55,4 @@
5455
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
5556
'sinkhorn_unbalanced', 'barycenter_unbalanced',
5657
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance',
57-
'smooth', 'stochastic', 'unbalanced', 'partial']
58+
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath']

0 commit comments

Comments
 (0)