Commit bbd8f20

Merge pull request #108 from kilianFatras/master

[MRG] Fix log and nbiter bug in gromov_wasserstein and gromov_wasserstein2

2 parents: 3635fc4 + 0280a34

4 files changed, +118 -81 lines

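Read together, the diffs below make two fixes. First, `gromov_wasserstein2` used to bind the dict returned by `cg(..., log=True, ...)` to the name `log`, shadowing its own boolean `log` flag; a non-empty dict is truthy, so the `if log:` branch always fired and callers got a `(distance, dict)` tuple even with `log=False`. The re-added definition binds the dict to `log_gw` instead. Second, `cg` gains a `numItermaxEmd` parameter (default 100000) that is forwarded to the inner `emd` call, whose iteration budget previously could not be raised from `cg`.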
ot/gromov.py

Lines changed: 77 additions & 79 deletions

@@ -276,7 +276,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
     - p : distribution in the source space
     - q : distribution in the target space
     - L : loss function to account for the misfit between the similarity matrices
-    - H : entropy
 
     Parameters
     ----------
@@ -343,6 +342,83 @@ def df(G):
     return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
 
 
+def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+    """
+    Returns the Gromov-Wasserstein discrepancy between (C1,p) and (C2,q)
+
+    The function solves the following optimization problem:
+
+    .. math::
+        GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+    Where:
+    - C1 : Metric cost matrix in the source space
+    - C2 : Metric cost matrix in the target space
+    - p : distribution in the source space
+    - q : distribution in the target space
+    - L : loss function to account for the misfit between the similarity matrices
+
+    Parameters
+    ----------
+    C1 : ndarray, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : ndarray, shape (nt, nt)
+        Metric cost matrix in the target space
+    p : ndarray, shape (ns,)
+        Distribution in the source space.
+    q : ndarray, shape (nt,)
+        Distribution in the target space.
+    loss_fun : str
+        Loss function used for the solver, either 'square_loss' or 'kl_loss'
+    max_iter : int, optional
+        Max number of iterations
+    tol : float, optional
+        Stop threshold on error (>0)
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        Record log if True
+    armijo : bool, optional
+        If True, the step of the line-search is found via an Armijo search. Else a closed form is used.
+        If there are convergence issues, use False.
+
+    Returns
+    -------
+    gw_dist : float
+        Gromov-Wasserstein distance
+    log : dict
+        Convergence information and coupling matrix
+
+    References
+    ----------
+    .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
+        "Gromov-Wasserstein averaging of kernel and distance matrices."
+        International Conference on Machine Learning (ICML). 2016.
+
+    .. [13] Mémoli, Facundo. Gromov-Wasserstein distances and the
+        metric approach to object matching. Foundations of Computational
+        Mathematics 11.4 (2011): 417-487.
+
+    """
+
+    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
+
+    G0 = p[:, None] * q[None, :]
+
+    def f(G):
+        return gwloss(constC, hC1, hC2, G)
+
+    def df(G):
+        return gwggrad(constC, hC1, hC2, G)
+
+    res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+    log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
+    log_gw['T'] = res
+    if log:
+        return log_gw['gw_dist'], log_gw
+    else:
+        return log_gw['gw_dist']
+
+
 def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
     """
     Computes the FGW transport between two graphs see [24]
@@ -506,84 +582,6 @@ def df(G):
     return log['fgw_dist']
 
 
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
-    """
-    Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q)
-
-    The function solves the following optimization problem:
-
-    .. math::
-        GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
-
-    Where :
-    - C1 : Metric cost matrix in the source space
-    - C2 : Metric cost matrix in the target space
-    - p : distribution in the source space
-    - q : distribution in the target space
-    - L : loss function to account for the misfit between the similarity matrices
-    - H : entropy
-
-    Parameters
-    ----------
-    C1 : ndarray, shape (ns, ns)
-        Metric cost matrix in the source space
-    C2 : ndarray, shape (nt, nt)
-        Metric cost matrix in the target space
-    p : ndarray, shape (ns,)
-        Distribution in the source space.
-    q : ndarray, shape (nt,)
-        Distribution in the target space.
-    loss_fun : str
-        loss function used for the solver either 'square_loss' or 'kl_loss'
-    max_iter : int, optional
-        Max number of iterations
-    tol : float, optional
-        Stop threshold on error (>0)
-    verbose : bool, optional
-        Print information along iterations
-    log : bool, optional
-        record log if True
-    armijo : bool, optional
-        If True the steps of the line-search is found via an armijo research. Else closed form is used.
-        If there is convergence issues use False.
-
-    Returns
-    -------
-    gw_dist : float
-        Gromov-Wasserstein distance
-    log : dict
-        convergence information and Coupling marix
-
-    References
-    ----------
-    .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon,
-        "Gromov-Wasserstein averaging of kernel and distance matrices."
-        International Conference on Machine Learning (ICML). 2016.
-
-    .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the
-        metric approach to object matching. Foundations of computational
-        mathematics 11.4 (2011): 417-487.
-
-    """
-
-    constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
-
-    G0 = p[:, None] * q[None, :]
-
-    def f(G):
-        return gwloss(constC, hC1, hC2, G)
-
-    def df(G):
-        return gwggrad(constC, hC1, hC2, G)
-
-    res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-    log['gw_dist'] = gwloss(constC, hC1, hC2, res)
-    log['T'] = res
-    if log:
-        return log['gw_dist'], log
-    else:
-        return log['gw_dist']
-
-
 def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
                                 max_iter=1000, tol=1e-9, verbose=False, log=False):
     """

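With the shadowing fix above, both call patterns of `gromov_wasserstein2` behave as documented. A minimal usage sketch; the 2x2 cost matrices and uniform weights are made up for illustration and are not taken from this PR:

```python
import numpy as np
import ot

# Toy inputs: two 2-point spaces with identical metrics, uniform weights.
C1 = np.array([[0., 1.], [1., 0.]])  # metric cost matrix, source space
C2 = np.array([[0., 1.], [1., 0.]])  # metric cost matrix, target space
p = np.array([0.5, 0.5])             # source distribution
q = np.array([0.5, 0.5])             # target distribution

# With the fix, log=False really returns just the scalar discrepancy...
gw_dist = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', log=False)

# ...and log=True returns the discrepancy plus the log dict, which also
# carries the optimal coupling under the 'T' key.
gw_dist2, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'square_loss', log=True)
print(gw_dist, gw_dist2, log['T'])
```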
ot/optim.py

Lines changed: 4 additions & 2 deletions

@@ -134,7 +134,7 @@ def solve_linesearch(cost, G, deltaG, Mi, f_val,
     return alpha, fc, f_val
 
 
-def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
+def cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000,
        stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
     """
     Solve the general regularized OT problem with conditional gradient
@@ -172,6 +172,8 @@ def cg(a, b, M, reg, f, df, G0=None, numItermax=200,
         initial guess (default is indep joint density)
     numItermax : int, optional
         Max number of iterations
+    numItermaxEmd : int, optional
+        Max number of iterations for emd
     stopThr : float, optional
         Stop threshold on the relative variation (>0)
     stopThr2 : float, optional
@@ -238,7 +240,7 @@ def cost(G):
         Mi += Mi.min()
 
         # solve linear program
-        Gc = emd(a, b, Mi)
+        Gc = emd(a, b, Mi, numItermax=numItermaxEmd)
 
         deltaG = Gc - G

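The new `numItermaxEmd` knob is simply threaded from `cg` into each inner `emd` solve. A small sketch, assuming made-up 3-bin histograms and the same quadratic `f`/`df` pair the new test uses:

```python
import numpy as np
import ot

# Illustrative 3-bin histograms (not from the PR).
a = np.array([0.5, 0.3, 0.2])  # source weights
b = np.array([0.2, 0.3, 0.5])  # target weights

# Squared Euclidean cost between bin locations 0, 1, 2.
x = np.arange(3, dtype=float).reshape(-1, 1)
M = ot.dist(x, x)
M /= M.max()

def f(G):
    return 0.5 * np.sum(G ** 2)  # quadratic regularization term

def df(G):
    return G  # its gradient

# numItermaxEmd raises the iteration cap of every inner emd call
# (default 100000); numItermax still bounds the outer CG loop.
G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df, numItermaxEmd=200000)
print(G.sum(1), G.sum(0))  # marginals should match a and b
```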
test/test_gromov.py

Lines changed: 4 additions & 0 deletions

@@ -44,10 +44,14 @@ def test_gromov():
 
     gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True)
 
+    gw_val = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=False)
+
     G = log['T']
 
     np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
 
+    np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1)  # cf log=False
+
     # check constraints
     np.testing.assert_allclose(
         p, G.sum(1), atol=1e-04)  # cf convergence gromov

test/test_optim.py

Lines changed: 33 additions & 0 deletions

@@ -37,6 +37,39 @@ def df(G):
     np.testing.assert_allclose(b, G.sum(0))
 
 
+def test_conditional_gradient2():
+    n = 4000  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    mu_t = np.array([4, 4])
+    cov_t = np.array([[1, -.8], [-.8, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
+    xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)
+
+    a, b = np.ones((n,)) / n, np.ones((n,)) / n
+
+    # loss matrix
+    M = ot.dist(xs, xt)
+    M /= M.max()
+
+    def f(G):
+        return 0.5 * np.sum(G**2)
+
+    def df(G):
+        return G
+
+    reg = 1e-1
+
+    G, log = ot.optim.cg(a, b, M, reg, f, df, numItermaxEmd=200000,
+                         verbose=True, log=True)
+
+    np.testing.assert_allclose(a, G.sum(1))
+    np.testing.assert_allclose(b, G.sum(0))
+
+
 def test_generalized_conditional_gradient():
 
     n_bins = 100  # nb bins
