Skip to content

Commit e8bb4e0

Browse files
correct typos; add projection_sparse_simplex
1 parent eeaca57 commit e8bb4e0

File tree

2 files changed

+123
-20
lines changed

2 files changed

+123
-20
lines changed

ot/sparse.py

Lines changed: 69 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
44
Implementation of :
55
Sparsity-Constrained Optimal Transport.
6-
Tianlin Liu, Joan Puigcerver, Mathieu Blondel.
7-
In Proc. of AISTATS 2018.
8-
https://arxiv.org/abs/1710.06276
6+
Liu, T., Puigcerver, J., & Blondel, M. (2023).
7+
Sparsity-constrained optimal transport.
8+
Proceedings of the Eleventh International Conference on
9+
Learning Representations (ICLR).
10+
https://arxiv.org/abs/2209.15466
911
1012
[50] Liu, T., Puigcerver, J., & Blondel, M. (2023).
1113
Sparsity-constrained optimal transport.
@@ -23,6 +25,67 @@
2325
from .backend import get_backend
2426

2527

28+
def projection_sparse_simplex(V, max_nz, z=1, axis=None):
29+
r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`.
30+
31+
.. math::
32+
P\left(\mathbf{V}, max_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z} \\ ||p||_0 \le \text{max_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2
33+
34+
Parameters
35+
----------
36+
V: ndarray, rank 2
37+
z: float or array
38+
If array, len(z) must be compatible with :math:`\mathbf{V}`
39+
axis: None or int
40+
- axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max_nz, z)`
41+
- axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max_nz, z_i)`
42+
- axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max_nz, z_j)`
43+
44+
Returns
45+
-------
46+
projection: ndarray, shape :math:`\mathbf{V}`.shape
47+
48+
References:
49+
Sparse projections onto the simplex
50+
Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch
51+
ICML 2013
52+
https://arxiv.org/abs/1206.1529
53+
"""
54+
if axis == 1:
55+
max_nz_indices = np.argpartition(
56+
V,
57+
kth=-max_nz,
58+
axis=1)[:, -max_nz:]
59+
# Record nonzero column indices in a descending order
60+
max_nz_indices = max_nz_indices[:, ::-1]
61+
62+
row_indices = np.arange(V.shape[0])[:, np.newaxis]
63+
64+
# Extract the top max_nz values for each row
65+
# and then project to simplex.
66+
U = V[row_indices, max_nz_indices]
67+
z = np.ones(len(U)) * z
68+
cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
69+
ind = np.arange(max_nz) + 1
70+
cond = U - cssv / ind > 0
71+
rho = np.count_nonzero(cond, axis=1)
72+
theta = cssv[np.arange(len(U)), rho - 1] / rho
73+
nz_projection = np.maximum(U - theta[:, np.newaxis], 0)
74+
75+
# Put the projection of max_nz_values to their original column indices
76+
# while keeping other values zero.
77+
sparse_projection = np.zeros_like(V)
78+
sparse_projection[row_indices, max_nz_indices] = nz_projection
79+
return sparse_projection
80+
81+
elif axis == 0:
82+
return projection_sparse_simplex(V.T, max_nz, z, axis=1).T
83+
84+
else:
85+
V = V.ravel().reshape(1, -1)
86+
return projection_sparse_simplex(V, max_nz, z, axis=1).ravel()
87+
88+
2689
class SparsityConstrained(ot.smooth.Regularization):
2790
""" Squared L2 regularization with sparsity constraints """
2891

@@ -42,22 +105,9 @@ def delta_Omega(self, X):
42105
return val, G
43106

44107
def max_Omega(self, X, b):
45-
# For each column of X, find top max_nz values and
46-
# their corresponding indices.
47-
max_nz_indices = np.argpartition(
48-
X,
49-
kth=-self.max_nz,
50-
axis=0)[-self.max_nz:]
51-
max_nz_values = X[max_nz_indices, np.arange(X.shape[1])]
52-
53-
# Project the top max_nz values onto the simplex.
54-
G_nz_values = ot.smooth.projection_simplex(
55-
max_nz_values / (b * self.gamma), axis=0)
56-
57-
# Put the projection of max_nz_values to their original indices
58-
# and set all other values zero.
59-
G = np.zeros_like(X)
60-
G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values
108+
# Project the scaled X onto the simplex with sparsity constraint.
109+
G = projection_sparse_simplex(
110+
X / (b * self.gamma), self.max_nz, axis=0)
61111
val = np.sum(X * G, axis=0)
62112
val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
63113
return val, G

test/test_sparse.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Tests for ot.smooth model """
1+
"""Tests for ot.sparse model """
22

33
# Author: Tianlin Liu <t.liu@unibas.ch>
44
#
@@ -60,3 +60,56 @@ def test_sparsity_constrained_ot_semi_dual():
6060
np.testing.assert_array_less(
6161
np.sum(plan > 0, axis=0),
6262
np.ones(n) * max_nz + 1)
63+
64+
65+
def test_projection_sparse_simplex():
66+
67+
def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None):
68+
r"""This is an equivalent but less efficient version
69+
of ot.sparse.projection_sparse_simplex, as it uses two
70+
sorts instead of one.
71+
"""
72+
73+
if axis == 0:
74+
# For each column of X, find top max_nz values and
75+
# their corresponding indices. This incurs a sort.
76+
max_nz_indices = np.argpartition(
77+
X,
78+
kth=-max_nz,
79+
axis=0)[-max_nz:]
80+
81+
max_nz_values = X[max_nz_indices, np.arange(X.shape[1])]
82+
83+
# Project the top max_nz values onto the simplex.
84+
# This incurs a second sort.
85+
G_nz_values = ot.smooth.projection_simplex(
86+
max_nz_values, z=z, axis=0)
87+
88+
# Put the projection of max_nz_values to their original indices
89+
# and set all other values zero.
90+
G = np.zeros_like(X)
91+
G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values
92+
return G
93+
elif axis == 1:
94+
return double_sort_projection_sparse_simplex(
95+
X.T, max_nz, z, axis=0).T
96+
97+
else:
98+
X = X.ravel().reshape(-1, 1)
99+
return double_sort_projection_sparse_simplex(
100+
X, max_nz, z, axis=0).ravel()
101+
102+
m, n = 5, 10
103+
rng = np.random.RandomState(0)
104+
X = rng.uniform(size=(m, n))
105+
max_nz = 3
106+
107+
for axis in [0, 1, None]:
108+
slow_sparse_proj = double_sort_projection_sparse_simplex(
109+
X, max_nz, axis=axis)
110+
fast_sparse_proj = ot.sparse.projection_sparse_simplex(
111+
X, max_nz, axis=axis)
112+
113+
# check that two versions produce the same result
114+
np.testing.assert_allclose(
115+
slow_sparse_proj, fast_sparse_proj)

0 commit comments

Comments
 (0)