Commit eeaca57

add sparsity-constrained ot functionality and example
1 parent 2bbfbbb commit eeaca57

6 files changed, +389 -1 lines changed

README.md

Lines changed: 2 additions & 0 deletions

@@ -308,3 +308,5 @@
 [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.

 [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
+
+[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).

examples/plot_OT_1D_smooth.py

Lines changed: 1 addition & 1 deletion

@@ -101,7 +101,7 @@
 pl.show()


-#%% Smooth OT with KL regularization
+#%% Smooth OT with squared l2 regularization

 lambd = 1e-1
 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')
Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""
==============================================
Sparsity-constrained optimal transport example
==============================================

This example illustrates EMD, squared l2 regularized OT, and sparsity-constrained OT plans.
The sparsity-constrained OT can be considered as a middle ground between EMD and squared l2 regularized OT.

"""

# Author: Tianlin Liu <t.liu@unibas.ch>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 5

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss

##############################################################################
# Generate data
# -------------


#%% parameters

n = 100  # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a = gauss(n, m=20, s=5)  # m = mean, s = std
b = gauss(n, m=60, s=10)

# loss matrix
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()


##############################################################################
# Plot distributions and loss matrix
# ----------------------------------

#%% plot the distributions

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a, 'b', label='Source distribution')
pl.plot(x, b, 'r', label='Target distribution')
pl.legend()

#%% plot distributions and loss matrix

pl.figure(2, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')


#%% EMD

# use fast 1D solver
G0 = ot.emd_1d(x, x, a, b)

# Equivalent to
# G0 = ot.emd(a, b, M)

pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')


#%% Smooth OT with squared l2 regularization

lambd = 1e-1
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')

pl.show()


#%% Sparsity-constrained OT

lambd = 1e-1
Gsc = ot.sparse.sparsity_constrained_ot_dual(a, b, M, lambd, max_nz=2)
pl.figure(5, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity constrained OT matrix; k=2.')

pl.show()

# %%
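To make the "middle ground" claim in the example's docstring concrete, here is a small follow-up sketch (not part of the commit) that counts the non-zero entries per column of each plan. It assumes the example above has been run so that np, G0, Gsm and Gsc are in scope; the tolerance 1e-10 is an arbitrary choice for treating tiny values as zero.

#%% Compare plan sparsity (illustrative follow-up, not in the commit)

# Count entries above a small tolerance in each column of every plan.
tol = 1e-10
for name, G in [('EMD', G0), ('smooth l2', Gsm), ('sparsity-constrained', Gsc)]:
    nz_per_col = (np.abs(G) > tol).sum(axis=0)
    print(f'{name}: max non-zeros per column = {nz_per_col.max()}')

With max_nz=2 in the example, the sparsity-constrained plan should report at most 2 non-zeros per column, sitting between the very sparse EMD plan and the dense l2-regularized one.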

ot/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -27,6 +27,7 @@
 from . import gromov
 from . import smooth
 from . import stochastic
+from . import sparse
 from . import unbalanced
 from . import partial
 from . import backend

ot/sparse.py

Lines changed: 229 additions & 0 deletions

@@ -0,0 +1,229 @@
"""
Sparsity-constrained optimal transport solvers.

Implementation of:

[50] Liu, T., Puigcerver, J., & Blondel, M. (2023).
Sparsity-constrained optimal transport.
Proceedings of the Eleventh International Conference on
Learning Representations (ICLR).
"""

# Author: Tianlin Liu <t.liu@unibas.ch>
#
# License: MIT License


import numpy as np
import ot
from .backend import get_backend


class SparsityConstrained(ot.smooth.Regularization):
    """ Squared L2 regularization with sparsity constraints """

    def __init__(self, max_nz, gamma=1.0):
        self.max_nz = max_nz
        self.gamma = gamma

    def delta_Omega(self, X):
        # For each column of X, find the entries that are not among the
        # top max_nz values.
        non_top_indices = np.argpartition(
            -X, self.max_nz, axis=0)[self.max_nz:]
        # Set these entries to -inf so they cannot contribute.
        X[non_top_indices, np.arange(X.shape[1])] = -np.inf
        max_X = np.maximum(X, 0)
        val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma)
        G = max_X / self.gamma
        return val, G

    def max_Omega(self, X, b):
        # For each column of X, find the top max_nz values and
        # their corresponding indices.
        max_nz_indices = np.argpartition(
            X,
            kth=-self.max_nz,
            axis=0)[-self.max_nz:]
        max_nz_values = X[max_nz_indices, np.arange(X.shape[1])]

        # Project the top max_nz values onto the simplex.
        G_nz_values = ot.smooth.projection_simplex(
            max_nz_values / (b * self.gamma), axis=0)

        # Put the projected values back at their original indices
        # and set all other entries to zero.
        G = np.zeros_like(X)
        G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values
        val = np.sum(X * G, axis=0)
        val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
        return val, G

    def Omega(self, T):
        return 0.5 * self.gamma * np.sum(T ** 2)


def sparsity_constrained_ot_dual(
        a, b, M, reg, max_nz,
        method="L-BFGS-B", stopThr=1e-9,
        numItermax=500, verbose=False, log=False):
    r"""
    Solve the sparsity-constrained OT problem in the dual and return the OT matrix.

    The function solves the sparsity-constrained OT in dual formulation of
    :ref:`[50] <references-sparsity-constrained-ot-dual>`.


    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain
    b : np.ndarray (nt,)
        samples weights in the target domain
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    max_nz : int
        Maximum number of non-zero entries permitted in each column of the
        optimal transport matrix.
    method : str
        Solver to use for scipy.optimize.minimize
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns, nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters


    .. _references-sparsity-constrained-ot-dual:
    References
    ----------
    .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.sinkhorn : Entropic regularized OT
    ot.smooth : Entropic regularized and squared l2 regularized OT
    ot.optim.cg : General regularized OT

    """

    nx = get_backend(a, b, M)
    max_nz = min(max_nz, M.shape[0])
    regul = SparsityConstrained(gamma=reg, max_nz=max_nz)

    a0, b0, M0 = a, b, M

    # convert to numpy
    a, b, M = nx.to_numpy(a, b, M)

    # solve dual
    alpha, beta, res = ot.smooth.solve_dual(
        a, b, M, regul,
        method=method,
        max_iter=numItermax,
        tol=stopThr, verbose=verbose)

    # reconstruct transport matrix
    G = nx.from_numpy(ot.smooth.get_plan_from_dual(alpha, beta, M, regul),
                      type_as=M0)

    if log:
        log = {'alpha': nx.from_numpy(alpha, type_as=a0),
               'beta': nx.from_numpy(beta, type_as=b0), 'res': res}
        return G, log
    else:
        return G


def sparsity_constrained_ot_semi_dual(
        a, b, M, reg, max_nz,
        method="L-BFGS-B", stopThr=1e-9,
        numItermax=500, verbose=False, log=False):
    r"""
    Solve the sparsity-constrained OT problem in the semi-dual and return the OT matrix.

    The function solves the sparsity-constrained OT in semi-dual formulation of
    :ref:`[50] <references-sparsity-constrained-ot-semi-dual>`.


    Parameters
    ----------
    a : np.ndarray (ns,)
        samples weights in the source domain
    b : np.ndarray (nt,)
        samples weights in the target domain
    M : np.ndarray (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    max_nz : int
        Maximum number of non-zero entries permitted in each column of the
        optimal transport matrix.
    method : str
        Solver to use for scipy.optimize.minimize
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns, nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters


    .. _references-sparsity-constrained-ot-semi-dual:
    References
    ----------
    .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.sinkhorn : Entropic regularized OT
    ot.smooth : Entropic regularized and squared l2 regularized OT
    ot.optim.cg : General regularized OT

    """

    max_nz = min(max_nz, M.shape[0])
    regul = SparsityConstrained(gamma=reg, max_nz=max_nz)

    # solve semi-dual
    alpha, res = ot.smooth.solve_semi_dual(
        a, b, M, regul, method=method, max_iter=numItermax,
        tol=stopThr, verbose=verbose)

    # reconstruct transport matrix
    G = ot.smooth.get_plan_from_semi_dual(alpha, b, M, regul)

    if log:
        log = {'alpha': alpha, 'res': res}
        return G, log
    else:
        return G
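For reference, a minimal usage sketch of the two solvers defined above (not part of the commit). It uses only existing POT helpers (ot.unif, ot.dist) plus the new ot.sparse module; the random point clouds, reg value, and max_nz value are arbitrary illustrative choices.

import numpy as np
import ot
import ot.sparse  # module added by this commit

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(30, 2)
a, b = ot.unif(20), ot.unif(30)   # uniform marginals
M = ot.dist(xs, xt)
M /= M.max()

# Dual solver: at most 3 non-zero entries per column of the plan.
G_dual = ot.sparse.sparsity_constrained_ot_dual(a, b, M, reg=0.1, max_nz=3)

# Semi-dual solver with the same sparsity constraint.
G_semi = ot.sparse.sparsity_constrained_ot_semi_dual(a, b, M, reg=0.1, max_nz=3)

# Each column should contain at most max_nz entries above numerical noise.
print((G_dual > 1e-10).sum(axis=0).max(), (G_semi > 1e-10).sum(axis=0).max())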
