Skip to content

[MRG] Efficient Discrete Multi Marginal Optimal Transport #454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
0327b5c
add demd.py to ot, add plot_demd_*.py to examples, updated init.py in…
xzyu02 Apr 4, 2023
99f4ae3
Merge branch 'PythonOT:master' into demd
xzyu02 Apr 4, 2023
27878b7
update REAMDME.md with citation to iclr23 paper and example link
xzyu02 Apr 5, 2023
4a3a4f1
chaneg directory of examples, build successful
xzyu02 Apr 5, 2023
94e0f44
fix small latex bug
xzyu02 Apr 5, 2023
2957510
update all.rst, examples and demd have passed pep8 and pyflake
xzyu02 Apr 5, 2023
708b756
add more detailed comments for examples
xzyu02 Apr 5, 2023
7c813a7
TODO: test module for demd, wrong demd index after build
xzyu02 Apr 5, 2023
707152c
add test module
xzyu02 Apr 8, 2023
02bd955
Merge branch 'PythonOT:master' into demd
xzyu02 Apr 8, 2023
81ab727
add contributors
xzyu02 Apr 8, 2023
706d6a5
pass pyflake checks, pass pep8
xzyu02 Apr 8, 2023
4e6f693
added the PR to the RELEASES.md file
xzyu02 Apr 8, 2023
74c87dc
merge from master
xzyu02 Apr 20, 2023
6730631
Merge branch 'master' into demd
rflamary Apr 24, 2023
08bb919
temporal changes with logs
xzyu02 May 4, 2023
29d16f4
init changes
xzyu02 May 7, 2023
4226eee
merge examples, demd -> lp.dmmot
xzyu02 May 7, 2023
3c7ab34
bug fix in plot_dmmot, some commenting/documenting edits
ronakrm May 17, 2023
7452379
dmmot example cleanup, some comments/plotting edits
ronakrm May 17, 2023
9c360bb
add dist_monge method
xzyu02 Jun 1, 2023
21f16c5
merge from incoming
xzyu02 Jun 1, 2023
697036d
all dmmot methods takes (n, d) shape A as input (follows POT style)
xzyu02 Jun 1, 2023
70326a6
passed pep8 and pyflake checks
xzyu02 Jun 1, 2023
b4b4609
merge from master
xzyu02 Jun 1, 2023
be09209
Merge branch 'master' into demd
rflamary Jun 12, 2023
8d16b0f
Merge branch 'master' into demd
xzyu02 Jun 12, 2023
6de193c
resolve test fail issue
xzyu02 Jun 12, 2023
e98c7ee
fix pep8 error
xzyu02 Jun 13, 2023
7339e8a
resolve issues from last review, pyflake and pep8 checked
xzyu02 Jul 5, 2023
fd444b7
add lr decay
xzyu02 Jul 7, 2023
bd2d2ec
Merge branch 'master' into demd
rflamary Jul 10, 2023
f531b9e
add more examples, ground cost options, test for uniqueness
xzyu02 Jul 26, 2023
c1ccd46
Merge branch 'demd' of github.com:x12hengyu/POT into demd
xzyu02 Jul 26, 2023
99d2e86
Merge branch 'master' into demd
xzyu02 Jul 26, 2023
b3cb896
remove additional experiment setting, not needed in this PR
xzyu02 Jul 28, 2023
2d22fc9
fixed line 14 1 blank line
xzyu02 Jul 29, 2023
018313b
Merge branch 'master' into demd
rflamary Aug 2, 2023
a7bde66
fix gradient computation link
xzyu02 Aug 2, 2023
b370202
Merge branch 'demd' of github.com:x12hengyu/POT into demd
xzyu02 Aug 2, 2023
24a69c0
Update ot/lp/dmmot.py
rflamary Aug 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ The contributors to this library are:
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein)
* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)

## Acknowledgments
Expand Down
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ POT provides the following generic OT solvers (links to examples):
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.

POT provides the following Machine Learning related solvers:
Expand Down Expand Up @@ -319,3 +320,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35.

[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804).

[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR).

[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019.
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)
- Add Entropic Wasserstein Component Analysis (ECWA) in ot.dr (PR #486)

- Added feature Efficient Discrete Multi Marginal Optimal Transport Regularization + examples (PR #454)

#### Closed issues

- Fix change in scipy API for `cdist` (PR #487)
Expand Down
158 changes: 158 additions & 0 deletions examples/others/plot_dmmot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
r"""
===============================================================================
Computing d-dimensional Barycenters via d-MMOT
===============================================================================

When the cost is discretized (Monge), the d-MMOT solver can more quickly
compute and minimize the distance between many distributions without the need
for intermediate barycenter computations. This example compares the time to
identify, and the quality of, solutions for the d-MMOT problem using a
primal/dual algorithm and classical LP barycenter approaches.
"""

# Author: Ronak Mehta <ronakrm@cs.wisc.edu>
# Xizheng Yu <xyu354@wisc.edu>
#
# License: MIT License

# %%
# Generating 2 distributions
# -----
import numpy as np
import matplotlib.pyplot as pl
import ot

np.random.seed(0)

n = 100
d = 2
# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m=mean, s=std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
A = np.vstack((a1, a2)).T
x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.legend()

# %%
# Minimize the distances among distributions, identify the Barycenter
# -----
# The objective being minimized is different for both methods, so the objective
# values cannot be compared.

# L2 Iteration
weights = np.ones(d) / d
l2_bary = A.dot(weights)

print('LP Iterations:')
weights = np.ones(d) / d
lp_bary, lp_log = ot.lp.barycenter(
A, M, weights, solver='interior-point', verbose=False, log=True)
print('Time\t: ', ot.toc(''))
print('Obj\t: ', lp_log['fun'])

print('')
print('Discrete MMOT Algorithm:')
ot.tic()
barys, log = ot.lp.dmmot_monge_1dgrid_optimize(
A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True)
dmmot_obj = log['primal objective']
print('Time\t: ', ot.toc(''))
print('Obj\t: ', dmmot_obj)

# %%
# Compare Barycenters in both methods
# -----
pl.figure(1, figsize=(6.4, 3))
for i in range(len(barys)):
if i == 0:
pl.plot(x, barys[i], 'g-*', label='Discrete MMOT')
else:
continue
# pl.plot(x, barys[i], 'g-*')
pl.plot(x, lp_bary, label='LP Barycenter')
pl.plot(x, l2_bary, label='L2 Barycenter')
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver')
pl.legend()


# %%
# More than 2 distributions
# --------------------------------------------------
# Generate 7 pseudorandom gaussian distributions with 50 bins.
n = 50 # nb bins
d = 7
vecsize = n * d

data = []
for i in range(d):
m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1)
a = ot.datasets.make_1D_gauss(n, m=m, s=5)
data.append(a)

x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
A = np.vstack(data).T

pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
pl.plot(x, data[i])

pl.title('Distributions')
pl.legend()

# %%
# Minimizing Distances Among Many Distributions
# ---------------
# The objective being minimized is different for both methods, so the objective
# values cannot be compared.

# Perform gradient descent optimization using the d-MMOT method.
barys = ot.lp.dmmot_monge_1dgrid_optimize(
A, niters=3000, lr_init=1e-4, lr_decay=0.997)

# after minimization, any distribution can be used as a estimate of barycenter.
bary = barys[0]

# Compute 1D Wasserstein barycenter using the L2/LP method
weights = ot.unif(d)
l2_bary = A.dot(weights)
lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point',
verbose=False, log=True)

# %%
# Compare Barycenters in both methods
# ---------
pl.figure(1, figsize=(6.4, 3))
pl.plot(x, bary, 'g-*', label='Discrete MMOT')
pl.plot(x, l2_bary, 'k', label='L2 Barycenter')
pl.plot(x, lp_bary, 'k-', label='LP Wasserstein')
pl.title('Barycenters')
pl.legend()

# %%
# Compare with original distributions
# ---------
pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
pl.plot(x, data[i])
for i in range(len(barys)):
if i == 0:
pl.plot(x, barys[i], 'g-*', label='Discrete MMOT')
else:
continue
# pl.plot(x, barys[i], 'g')
pl.plot(x, l2_bary, 'k^', label='L2')
pl.plot(x, lp_bary, 'o', color='grey', label='LP')
pl.title('Barycenters')
pl.legend()
pl.show()

# %%
4 changes: 3 additions & 1 deletion ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from . import cvx
from .cvx import barycenter
from .dmmot import *

# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
Expand All @@ -30,7 +31,8 @@

__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle',
'discrete_mmot', 'discrete_mmot_converge']


def check_number_threads(numThreads):
Expand Down
Loading