You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: README.md
+3Lines changed: 3 additions & 0 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -17,6 +17,7 @@ It provides the following solvers:
17
17
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
18
18
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
19
19
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
20
+
* Non regularized free support Wasserstein barycenters [20].
20
21
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
21
22
* Optimal transport for domain adaptation with group lasso regularization [5]
22
23
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
@@ -225,3 +226,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
225
226
[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) [Stochastic Optimization for Large-scale Optimal Transport](arXiv preprint arxiv:1605.08527). Advances in Neural Information Processing Systems (2016).
226
227
227
228
[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet, A.& Blondel, M. [Large-scale Optimal Transport and Mapping Estimation](https://arxiv.org/pdf/1711.02283.pdf). International Conference on Learning Representation (2018)
229
+
230
+
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
226
+
227
+
The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms.
228
+
This problem is considered in [1] (Algorithm 2). There are two differences with the following codes:
229
+
- we do not optimize over the weights
230
+
- we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting.
231
+
232
+
Parameters
233
+
----------
234
+
measures_locations : list of (k_i,d) np.ndarray
235
+
The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list)
236
+
measures_weights : list of (k_i,) np.ndarray
237
+
Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure
238
+
239
+
X_init : (k,d) np.ndarray
240
+
Initialization of the support locations (on k atoms) of the barycenter
241
+
b : (k,) np.ndarray
242
+
Initialization of the weights of the barycenter (non-negatives, sum to 1)
243
+
weights : (k,) np.ndarray
244
+
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
245
+
246
+
numItermax : int, optional
247
+
Max number of iterations
248
+
stopThr : float, optional
249
+
Stop threshol on error (>0)
250
+
verbose : bool, optional
251
+
Print information along iterations
252
+
log : bool, optional
253
+
record log if True
254
+
255
+
Returns
256
+
-------
257
+
X : (k,d) np.ndarray
258
+
Support locations (on k atoms) of the barycenter
259
+
260
+
References
261
+
----------
262
+
263
+
.. [1] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
264
+
265
+
.. [2] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
266
+
267
+
"""
268
+
269
+
iter_count=0
270
+
271
+
N=len(measures_locations)
272
+
k=X_init.shape[0]
273
+
d=X_init.shape[1]
274
+
ifbisNone:
275
+
b=np.ones((k,))/k
276
+
ifweightsisNone:
277
+
weights=np.ones((N,)) /N
278
+
279
+
X=X_init
280
+
281
+
log_dict= {}
282
+
displacement_square_norms= []
283
+
284
+
displacement_square_norm=stopThr+1.
285
+
286
+
while ( displacement_square_norm>stopThranditer_count<numItermax ):
287
+
288
+
T_sum=np.zeros((k, d))
289
+
290
+
for (measure_locations_i, measure_weights_i, weight_i) inzip(measures_locations, measures_weights, weights.tolist()):
0 commit comments