[MRG] Fix bug in gromov_wasserstein and gromov_wasserstein2 #108

Merged · 2 commits · Nov 18, 2019
156 changes: 77 additions & 79 deletions ot/gromov.py
@@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
- p : distribution in the source space
- q : distribution in the target space
- L : loss function to account for the misfit between the similarity matrices
- H : entropy

Parameters
----------
@@ -343,6 +342,83 @@ def df(G):
return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)


def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)

The function solves the following optimization problem:

.. math::
GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}

Where :
- C1 : Metric cost matrix in the source space
- C2 : Metric cost matrix in the target space
- p : distribution in the source space
- q : distribution in the target space
- L : loss function to account for the misfit between the similarity matrices

Parameters
----------
C1 : ndarray, shape (ns, ns)
Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
Metric cost matrix in the target space
p : ndarray, shape (ns,)
Distribution in the source space.
q : ndarray, shape (nt,)
Distribution in the target space.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
max_iter : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
armijo : bool, optional
If True, the step of the line-search is found via an Armijo search. Else, the closed form is used.
If there are convergence issues, use False.

Returns
-------
gw_dist : float
Gromov-Wasserstein distance
log : dict
Convergence information and coupling matrix

References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.

.. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.

"""

constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)

G0 = p[:, None] * q[None, :]

def f(G):
return gwloss(constC, hC1, hC2, G)

def df(G):
return gwggrad(constC, hC1, hC2, G)
res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
log_gw['T'] = res
if log:
return log_gw['gw_dist'], log_gw
else:
return log_gw['gw_dist']
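
Editorial note (not part of the diff): with this rewrite, `log=False` returns a plain float while `log=True` returns the value together with the log dict. A minimal usage sketch, assuming POT is importable as `ot` and using small random point clouds:

```python
# Illustrative sketch only, not part of the PR: calling the fixed solver.
import numpy as np
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(30, 2)   # source / target samples

C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)     # intra-domain cost matrices
C1 /= C1.max()
C2 /= C2.max()
p, q = ot.unif(20), ot.unif(30)               # uniform marginals

gw = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss')   # float
gw_log, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss',
                                             log=True)            # (float, dict)
print(gw, log['T'].shape)                     # coupling matrix stored under 'T'
```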


def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
"""
Computes the FGW transport between two graphs see [24]
@@ -506,84 +582,6 @@ def df(G):
return log['fgw_dist']


def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
"""
Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)

The function solves the following optimization problem:

.. math::
GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}

Where :
- C1 : Metric cost matrix in the source space
- C2 : Metric cost matrix in the target space
- p : distribution in the source space
- q : distribution in the target space
- L : loss function to account for the misfit between the similarity matrices
- H : entropy

Parameters
----------
C1 : ndarray, shape (ns, ns)
Metric cost matrix in the source space
C2 : ndarray, shape (nt, nt)
Metric cost matrix in the target space
p : ndarray, shape (ns,)
Distribution in the source space.
q : ndarray, shape (nt,)
Distribution in the target space.
loss_fun : str
loss function used for the solver either 'square_loss' or 'kl_loss'
max_iter : int, optional
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
armijo : bool, optional
If True the steps of the line-search is found via an armijo research. Else closed form is used.
If there is convergence issues use False.

Returns
-------
gw_dist : float
Gromov-Wasserstein distance
log : dict
convergence information and Coupling marix

References
----------
.. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.

.. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
metric approach to object matching. Foundations of computational
mathematics 11.4 (2011): 417-487.

"""

constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)

G0 = p[:, None] * q[None, :]

def f(G):
return gwloss(constC, hC1, hC2, G)

def df(G):
return gwggrad(constC, hC1, hC2, G)
res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
log['gw_dist'] = gwloss(constC, hC1, hC2, res)
log['T'] = res
if log:
return log['gw_dist'], log
else:
return log['gw_dist']
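
The block removed above contains the bug this PR fixes: `res, log = cg(..., log=True, ...)` rebinds the function's boolean `log` argument to the log dict returned by `cg`, so the following `if log:` is always truthy and a `(value, dict)` pair is returned even when the caller passed `log=False`. A stripped-down sketch of the shadowing pattern (hypothetical names, for illustration only):

```python
# Hypothetical reproduction of the name-shadowing bug; not POT code.
def buggy_solver(log=False):
    def _inner_solver():
        # stands in for cg(..., log=True): always returns (result, info dict)
        return 0.5, {'loss': [0.5]}

    res, log = _inner_solver()   # 'log' now holds a non-empty dict
    if log:                      # meant to test the original boolean flag,
        return res, log          # but a non-empty dict is always truthy
    return res

print(buggy_solver(log=False))   # (0.5, {'loss': [0.5]}) despite log=False
```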


def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
max_iter=1000, tol=1e-9, verbose=False, log=False):
"""
6 changes: 4 additions & 2 deletions ot/optim.py
@@ -134,7 +134,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
return alpha, fc, f_val


def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
"""
Solve the general regularized OT problem with conditional gradient
@@ -172,6 +172,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numItermaxEmd : int, optional
Max number of iterations for the inner emd solver
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
@@ -238,7 +240,7 @@ def cost(G):
Mi += Mi.min()

# solve linear program
Gc = emd(a, b, Mi)
Gc = emd(a, b, Mi, numItermax=numItermaxEmd)

deltaG = Gc - G
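
Since `gromov_wasserstein` and `gromov_wasserstein2` forward their `**kwargs` to `cg`, the new `numItermaxEmd` cap can also be set from the GW solvers. A sketch with illustrative values, not taken from the PR:

```python
# Sketch: raising the inner EMD iteration cap through the GW solver's **kwargs.
import numpy as np
import ot

rng = np.random.RandomState(1)
xs, xt = rng.randn(50, 2), rng.randn(60, 2)
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
C1 /= C1.max()
C2 /= C2.max()
p, q = ot.unif(50), ot.unif(60)

# numItermaxEmd travels through **kwargs to cg(), then down to emd()
T = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss',
                                 numItermaxEmd=500000)
```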

4 changes: 4 additions & 0 deletions test/test_gromov.py
@@ -44,10 +44,14 @@ def test_gromov():

gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)

gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)

G = log['T']

np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)

np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False

# check constraints
np.testing.assert_allclose(
p, G.sum(1), atol=1e-04) # cf convergence gromov
33 changes: 33 additions & 0 deletions test/test_optim.py
@@ -37,6 +37,39 @@ def df(G):
np.testing.assert_allclose(b, G.sum(0))


def test_conditional_gradient2():
n = 4000 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4])
cov_t = np.array([[1, -.8], [-.8, 1]])

xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

a, b = np.ones((n,)) / n, np.ones((n,)) / n

# loss matrix
M = ot.dist(xs, xt)
M /= M.max()

def f(G):
return 0.5 * np.sum(G**2)

def df(G):
return G

reg = 1e-1

G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
verbose=True, log=True)

np.testing.assert_allclose(a, G.sum(1))
np.testing.assert_allclose(b, G.sum(0))
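
To exercise the two tests touched by this PR locally, one option (a sketch, assuming pytest is installed) is to invoke them programmatically from the repository root:

```python
# Sketch: run only the tests exercised by this PR.
import pytest

pytest.main([
    "-q",
    "test/test_gromov.py::test_gromov",
    "test/test_optim.py::test_conditional_gradient2",
])
```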


def test_generalized_conditional_gradient():

n_bins = 100 # nb bins