3
3
4
4
Implementation of :
5
5
Sparsity-Constrained Optimal Transport.
6
- Tianlin Liu, Joan Puigcerver, Mathieu Blondel.
7
- In Proc. of AISTATS 2018.
8
- https://arxiv.org/abs/1710.06276
6
+ Liu, T., Puigcerver, J., & Blondel, M. (2023).
7
+ Sparsity-constrained optimal transport.
8
+ Proceedings of the Eleventh International Conference on
9
+ Learning Representations (ICLR).
10
+ https://arxiv.org/abs/2209.15466
9
11
10
12
[50] Liu, T., Puigcerver, J., & Blondel, M. (2023).
11
13
Sparsity-constrained optimal transport.
23
25
from .backend import get_backend
24
26
25
27
28
+ def projection_sparse_simplex (V , max_nz , z = 1 , axis = None ):
29
+ r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`.
30
+
31
+ .. math::
32
+ P\left(\mathbf{V}, max_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z} \\ ||p||_0 \le \text{max_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2
33
+
34
+ Parameters
35
+ ----------
36
+ V: ndarray, rank 2
37
+ z: float or array
38
+ If array, len(z) must be compatible with :math:`\mathbf{V}`
39
+ axis: None or int
40
+ - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max_nz, z)`
41
+ - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max_nz, z_i)`
42
+ - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max_nz, z_j)`
43
+
44
+ Returns
45
+ -------
46
+ projection: ndarray, shape :math:`\mathbf{V}`.shape
47
+
48
+ References:
49
+ Sparse projections onto the simplex
50
+ Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch
51
+ ICML 2013
52
+ https://arxiv.org/abs/1206.1529
53
+ """
54
+ if axis == 1 :
55
+ max_nz_indices = np .argpartition (
56
+ V ,
57
+ kth = - max_nz ,
58
+ axis = 1 )[:, - max_nz :]
59
+ # Record nonzero column indices in a descending order
60
+ max_nz_indices = max_nz_indices [:, ::- 1 ]
61
+
62
+ row_indices = np .arange (V .shape [0 ])[:, np .newaxis ]
63
+
64
+ # Extract the top max_nz values for each row
65
+ # and then project to simplex.
66
+ U = V [row_indices , max_nz_indices ]
67
+ z = np .ones (len (U )) * z
68
+ cssv = np .cumsum (U , axis = 1 ) - z [:, np .newaxis ]
69
+ ind = np .arange (max_nz ) + 1
70
+ cond = U - cssv / ind > 0
71
+ rho = np .count_nonzero (cond , axis = 1 )
72
+ theta = cssv [np .arange (len (U )), rho - 1 ] / rho
73
+ nz_projection = np .maximum (U - theta [:, np .newaxis ], 0 )
74
+
75
+ # Put the projection of max_nz_values to their original column indices
76
+ # while keeping other values zero.
77
+ sparse_projection = np .zeros_like (V )
78
+ sparse_projection [row_indices , max_nz_indices ] = nz_projection
79
+ return sparse_projection
80
+
81
+ elif axis == 0 :
82
+ return projection_sparse_simplex (V .T , max_nz , z , axis = 1 ).T
83
+
84
+ else :
85
+ V = V .ravel ().reshape (1 , - 1 )
86
+ return projection_sparse_simplex (V , max_nz , z , axis = 1 ).ravel ()
87
+
88
+
26
89
class SparsityConstrained (ot .smooth .Regularization ):
27
90
""" Squared L2 regularization with sparsity constraints """
28
91
@@ -42,22 +105,9 @@ def delta_Omega(self, X):
42
105
return val , G
43
106
44
107
def max_Omega (self , X , b ):
45
- # For each column of X, find top max_nz values and
46
- # their corresponding indices.
47
- max_nz_indices = np .argpartition (
48
- X ,
49
- kth = - self .max_nz ,
50
- axis = 0 )[- self .max_nz :]
51
- max_nz_values = X [max_nz_indices , np .arange (X .shape [1 ])]
52
-
53
- # Project the top max_nz values onto the simplex.
54
- G_nz_values = ot .smooth .projection_simplex (
55
- max_nz_values / (b * self .gamma ), axis = 0 )
56
-
57
- # Put the projection of max_nz_values to their original indices
58
- # and set all other values zero.
59
- G = np .zeros_like (X )
60
- G [max_nz_indices , np .arange (X .shape [1 ])] = G_nz_values
108
+ # Project the scaled X onto the simplex with sparsity constraint.
109
+ G = projection_sparse_simplex (
110
+ X / (b * self .gamma ), self .max_nz , axis = 0 )
61
111
val = np .sum (X * G , axis = 0 )
62
112
val -= 0.5 * self .gamma * b * np .sum (G * G , axis = 0 )
63
113
return val , G
0 commit comments