Skip to content

Commit e0ba31c

Browse files
Hv0nnusrflamary
andauthored
[MRG] Implementation of two news algorithms: SaGroW and PoGroW. (#275)
* Add two new algorithms to solve Gromov Wasserstein: Sampled Gromov Wasserstein and Pointwise Gromov Wasserstein. * Correct some lines in SaGroW and PoGroW to follow pep8 guide. * Change nb_samples name. Use rdm state. Change symmetric check. * Change names of len(p) and len(q) in SaGroW and PoGroW. * Re-add some deleted lines in the comments of gromov.py Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 96bf1a4 commit e0ba31c

File tree

4 files changed

+496
-6
lines changed

4 files changed

+496
-6
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ POT provides the following generic OT solvers (links to examples):
2828
* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
2929
* [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
3030
* [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
31+
* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
3132
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
3233
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
3334
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
@@ -198,6 +199,7 @@ The contributors to this library are
198199
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
199200
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
200201
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
202+
* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
201203
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
202204

203205
This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -286,3 +288,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
286288
[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
287289

288290
[32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
291+
292+
[33] Kerdoncuff T., Emonet R., Marc S. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MJL), 2021

examples/gromov/plot_gromov.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,37 @@
104104
pl.title('Entropic Gromov Wasserstein')
105105

106106
pl.show()
107+
108+
#############################################################################
109+
#
110+
# Compute GW with a scalable stochastic method with any loss function
111+
# ----------------------------------------------------------------------
112+
113+
114+
def loss(x, y):
115+
return np.abs(x - y)
116+
117+
118+
pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
119+
log=True)
120+
121+
sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
122+
log=True)
123+
124+
print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
125+
print('Variance estimated: ' + str(plog['gw_dist_std']))
126+
print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
127+
print('Variance estimated: ' + str(slog['gw_dist_std']))
128+
129+
130+
pl.figure(1, (10, 5))
131+
132+
pl.subplot(1, 2, 1)
133+
pl.imshow(pgw.toarray(), cmap='jet')
134+
pl.title('Pointwise Gromov Wasserstein')
135+
136+
pl.subplot(1, 2, 2)
137+
pl.imshow(sgw, cmap='jet')
138+
pl.title('Sampled Gromov Wasserstein')
139+
140+
pl.show()

0 commit comments

Comments
 (0)