From 917b445ad2b77e09944f3d71cc2c1d90d40c63d7 Mon Sep 17 00:00:00 2001 From: clvincen Date: Wed, 8 Nov 2023 11:28:36 +0100 Subject: [PATCH 1/4] correct independence of fgw barycenters to init --- RELEASES.md | 1 + ot/gromov/_gw.py | 27 ++++++++++++++------------- test/test_gromov.py | 38 +++++++++++++++++++++++++++----------- 3 files changed, 42 insertions(+), 24 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 8cf5ae342..6315c9908 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -21,6 +21,7 @@ - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) - Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520) - Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559) +- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #564) ## 0.9.1 *August 2023* diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 8ee68e917..a8fff380a 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -1005,13 +1005,15 @@ def fgw_barycenters( else: if init_X is None: X = nx.zeros((N, d), type_as=ps[0]) + else: X = init_X - T = [nx.outer(p, q) for q in ps] - Ms = [dist(X, Ys[s]) for s in range(len(Ys))] + if warmstartT: + T = [nx.outer(p, q) for q in ps] + # removed since 0.9.2 #if loss_fun == 'kl_loss': # armijo = True @@ -1030,11 +1032,19 @@ def fgw_barycenters( Cprev = C Xprev = X + if warmstartT: + T = [fused_gromov_wasserstein( + Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, + G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] + else: + T = [fused_gromov_wasserstein( + Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, + G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] + # T is N,ns if not fixed_features: Ys_temp 
= [y.T for y in Ys] X = update_feature_matrix(lambdas, Ys_temp, T, p).T - - Ms = [dist(X, Ys[s]) for s in range(len(Ys))] + Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: T_temp = [t.T for t in T] @@ -1044,15 +1054,6 @@ def fgw_barycenters( elif loss_fun == 'kl_loss': C = update_kl_loss(p, lambdas, T_temp, Cs) - if warmstartT: - T = [fused_gromov_wasserstein( - Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, - G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] - else: - T = [fused_gromov_wasserstein( - Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, - G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] - # T is N,ns err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) err_structure = nx.norm(C - Cprev) if log: diff --git a/test/test_gromov.py b/test/test_gromov.py index 9e873d5a0..f7ceec10e 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -1425,11 +1425,18 @@ def test_fgw_barycenter(nx): init_C /= init_C.max() init_Cb = nx.from_numpy(init_C) - Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, - alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, - p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 - ) + try: # to raise warning when `fixed_structure=True`and `init_C=None` + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, + alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False, + p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 + ) + except: + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, + alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, + p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 + ) Xb, Cb = 
nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) @@ -1437,12 +1444,21 @@ def test_fgw_barycenter(nx): init_X = rng.randn(n_samples, ys.shape[1]) init_Xb = nx.from_numpy(init_X) - Xb, Cb, logb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_Xb, - p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True - ) + try: # to raise warning when `fixed_features=True`and `init_X=None` + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=None, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + except: + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_Xb, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) + X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) From 85a0d5ba4f5f8454bb1e21778817497651964241 Mon Sep 17 00:00:00 2001 From: clvincen Date: Wed, 8 Nov 2023 11:45:45 +0100 Subject: [PATCH 2/4] fix pep8 and tests --- ot/gromov/_gw.py | 8 ++++---- test/test_gromov.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index a8fff380a..2da7c3819 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -167,7 +167,7 @@ def df(G): return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) # removed since 0.9.2 - #if loss_fun == 'kl_loss': + # if loss_fun == 
'kl_loss': # armijo = True # there is no closed form line-search with KL if armijo: @@ -479,7 +479,7 @@ def df(G): return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) # removed since 0.9.2 - #if loss_fun == 'kl_loss': + # if loss_fun == 'kl_loss': # armijo = True # there is no closed form line-search with KL if armijo: @@ -828,7 +828,7 @@ def gromov_barycenters( C = init_C # removed since 0.9.2 - #if loss_fun == 'kl_loss': + # if loss_fun == 'kl_loss': # armijo = True cpt = 0 @@ -1015,7 +1015,7 @@ def fgw_barycenters( T = [nx.outer(p, q) for q in ps] # removed since 0.9.2 - #if loss_fun == 'kl_loss': + # if loss_fun == 'kl_loss': # armijo = True cpt = 0 diff --git a/test/test_gromov.py b/test/test_gromov.py index f7ceec10e..154e8657b 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -1431,7 +1431,7 @@ def test_fgw_barycenter(nx): alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 ) - except: + except ot.utils.UndefinedParameter: Xb, Cb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, @@ -1451,7 +1451,7 @@ def test_fgw_barycenter(nx): p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, warmstartT=True, log=True, random_state=98765, verbose=True ) - except: + except ot.utils.UndefinedParameter: Xb, Cb, logb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, fixed_features=True, init_X=init_Xb, From b2760c9192d0ea944a372c9de27469c40df0b38d Mon Sep 17 00:00:00 2001 From: clvincen Date: Wed, 8 Nov 2023 11:48:23 +0100 Subject: [PATCH 3/4] correct PR id --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 6315c9908..c5d63750f 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -21,7 +21,7 @@ - Fix line search evaluating 
cost outside of the interpolation range (Issue #502, PR #504) - Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520) - Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559) -- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #564) +- Fix `fgw_barycenters` so that its result depends on the provided `init_C` and `init_X` (Issue #547, PR #566) ## 0.9.1 *August 2023* From 882202c6c0f83df56824094fb8db19f90b609814 Mon Sep 17 00:00:00 2001 From: clvincen Date: Wed, 8 Nov 2023 17:29:55 +0100 Subject: [PATCH 4/4] take into account comments --- ot/gromov/_gw.py | 16 ---------------- test/test_gromov.py | 29 ++++++++++++++--------------- 2 files changed, 14 insertions(+), 31 deletions(-) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 2da7c3819..aba5fa853 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -166,10 +166,6 @@ def df(G): def df(G): return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) - # removed since 0.9.2 - # if loss_fun == 'kl_loss': - # armijo = True # there is no closed form line-search with KL - if armijo: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) @@ -478,10 +474,6 @@ def df(G): def df(G): return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_)) - # removed since 0.9.2 - # if loss_fun == 'kl_loss': - # armijo = True # there is no closed form line-search with KL - if armijo: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs) @@ -827,10 +819,6 @@ def gromov_barycenters( else: C = init_C - # removed since 0.9.2 - # if loss_fun == 'kl_loss': - # armijo = True - cpt = 0 err = 1 @@ -1014,10 +1002,6 @@ def fgw_barycenters( if warmstartT: T = [nx.outer(p, q) for q in ps] - # removed since 0.9.2
- # if loss_fun == 'kl_loss': - # armijo = True - cpt = 0 err_feature = 1 err_structure = 1 diff --git a/test/test_gromov.py b/test/test_gromov.py index 154e8657b..13796c9eb 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -1425,18 +1425,18 @@ def test_fgw_barycenter(nx): init_C /= init_C.max() init_Cb = nx.from_numpy(init_C) - try: # to raise warning when `fixed_structure=True`and `init_C=None` + with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_structure=True`and `init_C=None` Xb, Cb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 ) - except ot.utils.UndefinedParameter: - Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, - alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, - p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 - ) + + Xb, Cb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, + alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False, + p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 + ) Xb, Cb = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) np.testing.assert_allclose(Xb.shape, (n_samples, ys.shape[1])) @@ -1444,20 +1444,19 @@ def test_fgw_barycenter(nx): init_X = rng.randn(n_samples, ys.shape[1]) init_Xb = nx.from_numpy(init_X) - try: # to raise warning when `fixed_features=True`and `init_X=None` + with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_features=True`and `init_X=None` Xb, Cb, logb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, fixed_features=True, init_X=None, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, warmstartT=True, log=True, 
random_state=98765, verbose=True ) - except ot.utils.UndefinedParameter: - Xb, Cb, logb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=True, init_X=init_Xb, - p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, - warmstartT=True, log=True, random_state=98765, verbose=True - ) + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=True, init_X=init_Xb, + p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, + warmstartT=True, log=True, random_state=98765, verbose=True + ) X, C = nx.to_numpy(Xb), nx.to_numpy(Cb) np.testing.assert_allclose(C.shape, (n_samples, n_samples))