Skip to content

Commit 6381c63

Browse files
add tests
1 parent 4d80644 commit 6381c63

File tree

4 files changed

+216
-23
lines changed

4 files changed

+216
-23
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
+ Wrapper for `geomloss`` solver on empirical samples (PR #571)
2121
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
2222
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
23+
+ Add new entropic BAPG solvers for GW and FGW (PR #581)
2324

2425
#### Closed issues
2526
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)

ot/gromov/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,13 @@
2020

2121
from ._bregman import (entropic_gromov_wasserstein,
2222
entropic_gromov_wasserstein2,
23+
entropic_BAPG_gromov_wasserstein,
24+
entropic_BAPG_gromov_wasserstein2,
2325
entropic_gromov_barycenters,
2426
entropic_fused_gromov_wasserstein,
2527
entropic_fused_gromov_wasserstein2,
28+
entropic_BAPG_fused_gromov_wasserstein,
29+
entropic_BAPG_fused_gromov_wasserstein2,
2630
entropic_fused_gromov_barycenters)
2731

2832
from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein,
@@ -49,8 +53,10 @@
4953
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
5054
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
5155
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
56+
'entropic_BAPG_gromov_wasserstein', 'entropic_BAPG_gromov_wasserstein2',
5257
'entropic_gromov_barycenters', 'entropic_fused_gromov_wasserstein',
53-
'entropic_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters',
58+
'entropic_fused_gromov_wasserstein2', 'entropic_BAPG_fused_gromov_wasserstein',
59+
'entropic_BAPG_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters',
5460
'GW_distance_estimation', 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein',
5561
'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2',
5662
'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2',

ot/gromov/_bregman.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,10 @@ def df(T):
509509

510510
cpt += 1
511511

512-
if abs(nx.sum(T) - 1) > 1e-5:
512+
if nx.any(nx.isnan(T)):
513513
warnings.warn("Solver failed to produce a transport plan. You might "
514-
"want to increase the regularization parameter `epsilon`.")
514+
"want to increase the regularization parameter `epsilon`.",
515+
UserWarning)
515516
if log:
516517
log['gw_dist'] = gwloss(constC, hC1, hC2, T, nx)
517518

@@ -1328,9 +1329,10 @@ def df(T):
13281329

13291330
cpt += 1
13301331

1331-
if abs(nx.sum(T) - 1) > 1e-5:
1332+
if nx.any(nx.isnan(T)):
13321333
warnings.warn("Solver failed to produce a transport plan. You might "
1333-
"want to increase the regularization parameter `epsilon`.")
1334+
"want to increase the regularization parameter `epsilon`.",
1335+
UserWarning)
13341336
if log:
13351337
log['fgw_dist'] = (1 - alpha) * nx.sum(M * T) + alpha * gwloss(constC, hC1, hC2, T, nx)
13361338

test/test_gromov.py

Lines changed: 202 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -570,20 +570,108 @@ def test_entropic_gromov_dtype_device(nx):
570570

571571
C1b, C2b, pb, qb = nx.from_numpy(C1, C2, p, q, type_as=tp)
572572

573-
for solver in ['PGD', 'PPA']:
574-
Gb = ot.gromov.entropic_gromov_wasserstein(
575-
C1b, C2b, pb, qb, 'square_loss', epsilon=1e-1, max_iter=5,
576-
solver=solver, verbose=True
577-
)
578-
gw_valb = ot.gromov.entropic_gromov_wasserstein2(
579-
C1b, C2b, pb, qb, 'square_loss', epsilon=1e-1, max_iter=5,
580-
solver=solver, verbose=True
581-
)
573+
for solver in ['PGD', 'PPA', 'BAPG']:
574+
if solver == 'BAPG':
575+
Gb = ot.gromov.entropic_BAPG_gromov_wasserstein(
576+
C1b, C2b, pb, qb, max_iter=2, verbose=True)
577+
gw_valb = ot.gromov.entropic_BAPG_gromov_wasserstein2(
578+
C1b, C2b, pb, qb, max_iter=2, verbose=True)
579+
else:
580+
Gb = ot.gromov.entropic_gromov_wasserstein(
581+
C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True)
582+
gw_valb = ot.gromov.entropic_gromov_wasserstein2(
583+
C1b, C2b, pb, qb, max_iter=2, solver=solver, verbose=True)
582584

583585
nx.assert_same_dtype_device(C1b, Gb)
584586
nx.assert_same_dtype_device(C1b, gw_valb)
585587

586588

589+
def test_entropic_BAPG_gromov(nx):
590+
n_samples = 10 # nb samples
591+
592+
mu_s = np.array([0, 0])
593+
cov_s = np.array([[1, 0], [0, 1]])
594+
595+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
596+
597+
xt = xs[::-1].copy()
598+
599+
p = ot.unif(n_samples)
600+
q = ot.unif(n_samples)
601+
G0 = p[:, None] * q[None, :]
602+
C1 = ot.dist(xs, xs)
603+
C2 = ot.dist(xt, xt)
604+
605+
C1 /= C1.max()
606+
C2 /= C2.max()
607+
608+
C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)
609+
610+
# complete test with marginal loss = True
611+
marginal_loss = True
612+
with pytest.raises(ValueError):
613+
loss_fun = 'weird_loss_fun'
614+
G, log = ot.gromov.entropic_BAPG_gromov_wasserstein(
615+
C1, C2, None, q, loss_fun, symmetric=None, G0=G0,
616+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss,
617+
verbose=True, log=True)
618+
619+
G, log = ot.gromov.entropic_BAPG_gromov_wasserstein(
620+
C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
621+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss,
622+
verbose=True, log=True)
623+
Gb = nx.to_numpy(ot.gromov.entropic_BAPG_gromov_wasserstein(
624+
C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=None,
625+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True,
626+
log=False
627+
))
628+
629+
# check constraints
630+
np.testing.assert_allclose(G, Gb, atol=1e-06)
631+
np.testing.assert_allclose(
632+
p, Gb.sum(1), atol=1e-02) # cf convergence gromov
633+
np.testing.assert_allclose(
634+
q, Gb.sum(0), atol=1e-02) # cf convergence gromov
635+
636+
with pytest.warns(UserWarning):
637+
638+
gw = ot.gromov.entropic_BAPG_gromov_wasserstein2(
639+
C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
640+
max_iter=10, epsilon=1e-2, marginal_loss=marginal_loss, log=False)
641+
642+
gw, log = ot.gromov.entropic_BAPG_gromov_wasserstein2(
643+
C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
644+
max_iter=10, epsilon=1., marginal_loss=marginal_loss, log=True)
645+
gwb, logb = ot.gromov.entropic_BAPG_gromov_wasserstein2(
646+
C1b, C2b, pb, qb, 'kl_loss', symmetric=None, G0=G0b,
647+
max_iter=10, epsilon=1., marginal_loss=marginal_loss, log=True)
648+
gwb = nx.to_numpy(gwb)
649+
650+
G = log['T']
651+
Gb = nx.to_numpy(logb['T'])
652+
653+
np.testing.assert_allclose(gw, gwb, atol=1e-06)
654+
np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
655+
656+
# check constraints
657+
np.testing.assert_allclose(G, Gb, atol=1e-06)
658+
np.testing.assert_allclose(
659+
p, Gb.sum(1), atol=1e-02) # cf convergence gromov
660+
np.testing.assert_allclose(
661+
q, Gb.sum(0), atol=1e-02) # cf convergence gromov
662+
663+
marginal_loss = False
664+
G, log = ot.gromov.entropic_BAPG_gromov_wasserstein(
665+
C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
666+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss,
667+
verbose=True, log=True)
668+
Gb = nx.to_numpy(ot.gromov.entropic_BAPG_gromov_wasserstein(
669+
C1b, C2b, pb, None, 'square_loss', symmetric=False, G0=None,
670+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True,
671+
log=False
672+
))
673+
674+
587675
@pytest.skip_backend("tf", reason="test very slow with tf backend")
588676
def test_entropic_fgw(nx):
589677
n_samples = 5 # nb samples
@@ -722,6 +810,99 @@ def test_entropic_proximal_fgw(nx):
722810
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
723811

724812

813+
def test_entropic_BAPG_fgw(nx):
814+
n_samples = 5 # nb samples
815+
816+
mu_s = np.array([0, 0])
817+
cov_s = np.array([[1, 0], [0, 1]])
818+
819+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
820+
821+
xt = xs[::-1].copy()
822+
823+
rng = np.random.RandomState(42)
824+
ys = rng.randn(xs.shape[0], 2)
825+
yt = ys[::-1].copy()
826+
827+
p = ot.unif(n_samples)
828+
q = ot.unif(n_samples)
829+
G0 = p[:, None] * q[None, :]
830+
831+
C1 = ot.dist(xs, xs)
832+
C2 = ot.dist(xt, xt)
833+
834+
C1 /= C1.max()
835+
C2 /= C2.max()
836+
837+
M = ot.dist(ys, yt)
838+
839+
Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)
840+
841+
with pytest.raises(ValueError):
842+
loss_fun = 'weird_loss_fun'
843+
G, log = ot.gromov.entropic_BAPG_fused_gromov_wasserstein(
844+
M, C1, C2, p, q, loss_fun=loss_fun, max_iter=1, log=True)
845+
846+
# complete test with marginal loss = True
847+
marginal_loss = True
848+
849+
G, log = ot.gromov.entropic_BAPG_fused_gromov_wasserstein(
850+
M, C1, C2, p, q, 'square_loss', symmetric=None, G0=G0,
851+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, log=True)
852+
Gb = nx.to_numpy(ot.gromov.entropic_BAPG_fused_gromov_wasserstein(
853+
Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=True, G0=None,
854+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True))
855+
856+
# check constraints
857+
np.testing.assert_allclose(G, Gb, atol=1e-06)
858+
np.testing.assert_allclose(
859+
p, Gb.sum(1), atol=1e-02) # cf convergence gromov
860+
np.testing.assert_allclose(
861+
q, Gb.sum(0), atol=1e-02) # cf convergence gromov
862+
863+
with pytest.warns(UserWarning):
864+
865+
fgw = ot.gromov.entropic_BAPG_fused_gromov_wasserstein2(
866+
M, C1, C2, p, q, 'kl_loss', symmetric=False, G0=None,
867+
max_iter=10, epsilon=1e-3, marginal_loss=marginal_loss, log=False)
868+
869+
fgw, log = ot.gromov.entropic_BAPG_fused_gromov_wasserstein2(
870+
M, C1, C2, p, None, 'kl_loss', symmetric=True, G0=None,
871+
max_iter=5, epsilon=1, marginal_loss=marginal_loss, log=True)
872+
fgwb, logb = ot.gromov.entropic_BAPG_fused_gromov_wasserstein2(
873+
Mb, C1b, C2b, None, qb, 'kl_loss', symmetric=None, G0=G0b,
874+
max_iter=5, epsilon=1, marginal_loss=marginal_loss, log=True)
875+
fgwb = nx.to_numpy(fgwb)
876+
877+
G = log['T']
878+
Gb = nx.to_numpy(logb['T'])
879+
880+
np.testing.assert_allclose(fgw, fgwb, atol=1e-06)
881+
np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
882+
883+
# check constraints
884+
np.testing.assert_allclose(G, Gb, atol=1e-06)
885+
np.testing.assert_allclose(
886+
p, Gb.sum(1), atol=1e-02) # cf convergence gromov
887+
np.testing.assert_allclose(
888+
q, Gb.sum(0), atol=1e-02) # cf convergence gromov
889+
890+
# Tests with marginal_loss = False
891+
marginal_loss = False
892+
G, log = ot.gromov.entropic_BAPG_fused_gromov_wasserstein(
893+
M, C1, C2, p, q, 'square_loss', symmetric=False, G0=G0,
894+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, log=True)
895+
Gb = nx.to_numpy(ot.gromov.entropic_BAPG_fused_gromov_wasserstein(
896+
Mb, C1b, C2b, pb, qb, 'square_loss', symmetric=None, G0=None,
897+
epsilon=1e-1, max_iter=10, marginal_loss=marginal_loss, verbose=True))
898+
# check constraints
899+
np.testing.assert_allclose(G, Gb, atol=1e-06)
900+
np.testing.assert_allclose(
901+
p, Gb.sum(1), atol=1e-02) # cf convergence gromov
902+
np.testing.assert_allclose(
903+
q, Gb.sum(0), atol=1e-02) # cf convergence gromov
904+
905+
725906
def test_asymmetric_entropic_fgw(nx):
726907
n_samples = 5 # nb samples
727908
rng = np.random.RandomState(0)
@@ -797,15 +978,18 @@ def test_entropic_fgw_dtype_device(nx):
797978

798979
Mb, C1b, C2b, pb, qb = nx.from_numpy(M, C1, C2, p, q, type_as=tp)
799980

800-
for solver in ['PGD', 'PPA']:
801-
Gb = ot.gromov.entropic_fused_gromov_wasserstein(
802-
Mb, C1b, C2b, pb, qb, 'square_loss', epsilon=0.1, max_iter=5,
803-
solver=solver, verbose=True
804-
)
805-
fgw_valb = ot.gromov.entropic_fused_gromov_wasserstein2(
806-
Mb, C1b, C2b, pb, qb, 'square_loss', epsilon=0.1, max_iter=5,
807-
solver=solver, verbose=True
808-
)
981+
for solver in ['PGD', 'PPA', 'BAPG']:
982+
if solver == 'BAPG':
983+
Gb = ot.gromov.entropic_BAPG_fused_gromov_wasserstein(
984+
Mb, C1b, C2b, pb, qb, max_iter=2)
985+
fgw_valb = ot.gromov.entropic_BAPG_fused_gromov_wasserstein2(
986+
Mb, C1b, C2b, pb, qb, max_iter=2)
987+
988+
else:
989+
Gb = ot.gromov.entropic_fused_gromov_wasserstein(
990+
Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver)
991+
fgw_valb = ot.gromov.entropic_fused_gromov_wasserstein2(
992+
Mb, C1b, C2b, pb, qb, max_iter=2, solver=solver)
809993

810994
nx.assert_same_dtype_device(C1b, Gb)
811995
nx.assert_same_dtype_device(C1b, fgw_valb)

0 commit comments

Comments
 (0)