@@ -570,20 +570,108 @@ def test_entropic_gromov_dtype_device(nx):
570
570
571
571
C1b , C2b , pb , qb = nx .from_numpy (C1 , C2 , p , q , type_as = tp )
572
572
573
- for solver in ['PGD' , 'PPA' ]:
574
- Gb = ot .gromov .entropic_gromov_wasserstein (
575
- C1b , C2b , pb , qb , 'square_loss' , epsilon = 1e-1 , max_iter = 5 ,
576
- solver = solver , verbose = True
577
- )
578
- gw_valb = ot .gromov .entropic_gromov_wasserstein2 (
579
- C1b , C2b , pb , qb , 'square_loss' , epsilon = 1e-1 , max_iter = 5 ,
580
- solver = solver , verbose = True
581
- )
573
+ for solver in ['PGD' , 'PPA' , 'BAPG' ]:
574
+ if solver == 'BAPG' :
575
+ Gb = ot .gromov .entropic_BAPG_gromov_wasserstein (
576
+ C1b , C2b , pb , qb , max_iter = 2 , verbose = True )
577
+ gw_valb = ot .gromov .entropic_BAPG_gromov_wasserstein2 (
578
+ C1b , C2b , pb , qb , max_iter = 2 , verbose = True )
579
+ else :
580
+ Gb = ot .gromov .entropic_gromov_wasserstein (
581
+ C1b , C2b , pb , qb , max_iter = 2 , solver = solver , verbose = True )
582
+ gw_valb = ot .gromov .entropic_gromov_wasserstein2 (
583
+ C1b , C2b , pb , qb , max_iter = 2 , solver = solver , verbose = True )
582
584
583
585
nx .assert_same_dtype_device (C1b , Gb )
584
586
nx .assert_same_dtype_device (C1b , gw_valb )
585
587
586
588
589
+ def test_entropic_BAPG_gromov (nx ):
590
+ n_samples = 10 # nb samples
591
+
592
+ mu_s = np .array ([0 , 0 ])
593
+ cov_s = np .array ([[1 , 0 ], [0 , 1 ]])
594
+
595
+ xs = ot .datasets .make_2D_samples_gauss (n_samples , mu_s , cov_s , random_state = 42 )
596
+
597
+ xt = xs [::- 1 ].copy ()
598
+
599
+ p = ot .unif (n_samples )
600
+ q = ot .unif (n_samples )
601
+ G0 = p [:, None ] * q [None , :]
602
+ C1 = ot .dist (xs , xs )
603
+ C2 = ot .dist (xt , xt )
604
+
605
+ C1 /= C1 .max ()
606
+ C2 /= C2 .max ()
607
+
608
+ C1b , C2b , pb , qb , G0b = nx .from_numpy (C1 , C2 , p , q , G0 )
609
+
610
+ # complete test with marginal loss = True
611
+ marginal_loss = True
612
+ with pytest .raises (ValueError ):
613
+ loss_fun = 'weird_loss_fun'
614
+ G , log = ot .gromov .entropic_BAPG_gromov_wasserstein (
615
+ C1 , C2 , None , q , loss_fun , symmetric = None , G0 = G0 ,
616
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss ,
617
+ verbose = True , log = True )
618
+
619
+ G , log = ot .gromov .entropic_BAPG_gromov_wasserstein (
620
+ C1 , C2 , None , q , 'square_loss' , symmetric = None , G0 = G0 ,
621
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss ,
622
+ verbose = True , log = True )
623
+ Gb = nx .to_numpy (ot .gromov .entropic_BAPG_gromov_wasserstein (
624
+ C1b , C2b , pb , None , 'square_loss' , symmetric = True , G0 = None ,
625
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss , verbose = True ,
626
+ log = False
627
+ ))
628
+
629
+ # check constraints
630
+ np .testing .assert_allclose (G , Gb , atol = 1e-06 )
631
+ np .testing .assert_allclose (
632
+ p , Gb .sum (1 ), atol = 1e-02 ) # cf convergence gromov
633
+ np .testing .assert_allclose (
634
+ q , Gb .sum (0 ), atol = 1e-02 ) # cf convergence gromov
635
+
636
+ with pytest .warns (UserWarning ):
637
+
638
+ gw = ot .gromov .entropic_BAPG_gromov_wasserstein2 (
639
+ C1 , C2 , p , q , 'kl_loss' , symmetric = False , G0 = None ,
640
+ max_iter = 10 , epsilon = 1e-2 , marginal_loss = marginal_loss , log = False )
641
+
642
+ gw , log = ot .gromov .entropic_BAPG_gromov_wasserstein2 (
643
+ C1 , C2 , p , q , 'kl_loss' , symmetric = False , G0 = None ,
644
+ max_iter = 10 , epsilon = 1. , marginal_loss = marginal_loss , log = True )
645
+ gwb , logb = ot .gromov .entropic_BAPG_gromov_wasserstein2 (
646
+ C1b , C2b , pb , qb , 'kl_loss' , symmetric = None , G0 = G0b ,
647
+ max_iter = 10 , epsilon = 1. , marginal_loss = marginal_loss , log = True )
648
+ gwb = nx .to_numpy (gwb )
649
+
650
+ G = log ['T' ]
651
+ Gb = nx .to_numpy (logb ['T' ])
652
+
653
+ np .testing .assert_allclose (gw , gwb , atol = 1e-06 )
654
+ np .testing .assert_allclose (gw , 0 , atol = 1e-1 , rtol = 1e-1 )
655
+
656
+ # check constraints
657
+ np .testing .assert_allclose (G , Gb , atol = 1e-06 )
658
+ np .testing .assert_allclose (
659
+ p , Gb .sum (1 ), atol = 1e-02 ) # cf convergence gromov
660
+ np .testing .assert_allclose (
661
+ q , Gb .sum (0 ), atol = 1e-02 ) # cf convergence gromov
662
+
663
+ marginal_loss = False
664
+ G , log = ot .gromov .entropic_BAPG_gromov_wasserstein (
665
+ C1 , C2 , None , q , 'square_loss' , symmetric = None , G0 = G0 ,
666
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss ,
667
+ verbose = True , log = True )
668
+ Gb = nx .to_numpy (ot .gromov .entropic_BAPG_gromov_wasserstein (
669
+ C1b , C2b , pb , None , 'square_loss' , symmetric = False , G0 = None ,
670
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss , verbose = True ,
671
+ log = False
672
+ ))
673
+
674
+
587
675
@pytest .skip_backend ("tf" , reason = "test very slow with tf backend" )
588
676
def test_entropic_fgw (nx ):
589
677
n_samples = 5 # nb samples
@@ -722,6 +810,99 @@ def test_entropic_proximal_fgw(nx):
722
810
q , Gb .sum (0 ), atol = 1e-04 ) # cf convergence gromov
723
811
724
812
813
+ def test_entropic_BAPG_fgw (nx ):
814
+ n_samples = 5 # nb samples
815
+
816
+ mu_s = np .array ([0 , 0 ])
817
+ cov_s = np .array ([[1 , 0 ], [0 , 1 ]])
818
+
819
+ xs = ot .datasets .make_2D_samples_gauss (n_samples , mu_s , cov_s , random_state = 42 )
820
+
821
+ xt = xs [::- 1 ].copy ()
822
+
823
+ rng = np .random .RandomState (42 )
824
+ ys = rng .randn (xs .shape [0 ], 2 )
825
+ yt = ys [::- 1 ].copy ()
826
+
827
+ p = ot .unif (n_samples )
828
+ q = ot .unif (n_samples )
829
+ G0 = p [:, None ] * q [None , :]
830
+
831
+ C1 = ot .dist (xs , xs )
832
+ C2 = ot .dist (xt , xt )
833
+
834
+ C1 /= C1 .max ()
835
+ C2 /= C2 .max ()
836
+
837
+ M = ot .dist (ys , yt )
838
+
839
+ Mb , C1b , C2b , pb , qb , G0b = nx .from_numpy (M , C1 , C2 , p , q , G0 )
840
+
841
+ with pytest .raises (ValueError ):
842
+ loss_fun = 'weird_loss_fun'
843
+ G , log = ot .gromov .entropic_BAPG_fused_gromov_wasserstein (
844
+ M , C1 , C2 , p , q , loss_fun = loss_fun , max_iter = 1 , log = True )
845
+
846
+ # complete test with marginal loss = True
847
+ marginal_loss = True
848
+
849
+ G , log = ot .gromov .entropic_BAPG_fused_gromov_wasserstein (
850
+ M , C1 , C2 , p , q , 'square_loss' , symmetric = None , G0 = G0 ,
851
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss , log = True )
852
+ Gb = nx .to_numpy (ot .gromov .entropic_BAPG_fused_gromov_wasserstein (
853
+ Mb , C1b , C2b , pb , qb , 'square_loss' , symmetric = True , G0 = None ,
854
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss , verbose = True ))
855
+
856
+ # check constraints
857
+ np .testing .assert_allclose (G , Gb , atol = 1e-06 )
858
+ np .testing .assert_allclose (
859
+ p , Gb .sum (1 ), atol = 1e-02 ) # cf convergence gromov
860
+ np .testing .assert_allclose (
861
+ q , Gb .sum (0 ), atol = 1e-02 ) # cf convergence gromov
862
+
863
+ with pytest .warns (UserWarning ):
864
+
865
+ fgw = ot .gromov .entropic_BAPG_fused_gromov_wasserstein2 (
866
+ M , C1 , C2 , p , q , 'kl_loss' , symmetric = False , G0 = None ,
867
+ max_iter = 10 , epsilon = 1e-3 , marginal_loss = marginal_loss , log = False )
868
+
869
+ fgw , log = ot .gromov .entropic_BAPG_fused_gromov_wasserstein2 (
870
+ M , C1 , C2 , p , None , 'kl_loss' , symmetric = True , G0 = None ,
871
+ max_iter = 5 , epsilon = 1 , marginal_loss = marginal_loss , log = True )
872
+ fgwb , logb = ot .gromov .entropic_BAPG_fused_gromov_wasserstein2 (
873
+ Mb , C1b , C2b , None , qb , 'kl_loss' , symmetric = None , G0 = G0b ,
874
+ max_iter = 5 , epsilon = 1 , marginal_loss = marginal_loss , log = True )
875
+ fgwb = nx .to_numpy (fgwb )
876
+
877
+ G = log ['T' ]
878
+ Gb = nx .to_numpy (logb ['T' ])
879
+
880
+ np .testing .assert_allclose (fgw , fgwb , atol = 1e-06 )
881
+ np .testing .assert_allclose (fgw , 0 , atol = 1e-1 , rtol = 1e-1 )
882
+
883
+ # check constraints
884
+ np .testing .assert_allclose (G , Gb , atol = 1e-06 )
885
+ np .testing .assert_allclose (
886
+ p , Gb .sum (1 ), atol = 1e-02 ) # cf convergence gromov
887
+ np .testing .assert_allclose (
888
+ q , Gb .sum (0 ), atol = 1e-02 ) # cf convergence gromov
889
+
890
+ # Tests with marginal_loss = False
891
+ marginal_loss = False
892
+ G , log = ot .gromov .entropic_BAPG_fused_gromov_wasserstein (
893
+ M , C1 , C2 , p , q , 'square_loss' , symmetric = False , G0 = G0 ,
894
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss , log = True )
895
+ Gb = nx .to_numpy (ot .gromov .entropic_BAPG_fused_gromov_wasserstein (
896
+ Mb , C1b , C2b , pb , qb , 'square_loss' , symmetric = None , G0 = None ,
897
+ epsilon = 1e-1 , max_iter = 10 , marginal_loss = marginal_loss , verbose = True ))
898
+ # check constraints
899
+ np .testing .assert_allclose (G , Gb , atol = 1e-06 )
900
+ np .testing .assert_allclose (
901
+ p , Gb .sum (1 ), atol = 1e-02 ) # cf convergence gromov
902
+ np .testing .assert_allclose (
903
+ q , Gb .sum (0 ), atol = 1e-02 ) # cf convergence gromov
904
+
905
+
725
906
def test_asymmetric_entropic_fgw (nx ):
726
907
n_samples = 5 # nb samples
727
908
rng = np .random .RandomState (0 )
@@ -797,15 +978,18 @@ def test_entropic_fgw_dtype_device(nx):
797
978
798
979
Mb , C1b , C2b , pb , qb = nx .from_numpy (M , C1 , C2 , p , q , type_as = tp )
799
980
800
- for solver in ['PGD' , 'PPA' ]:
801
- Gb = ot .gromov .entropic_fused_gromov_wasserstein (
802
- Mb , C1b , C2b , pb , qb , 'square_loss' , epsilon = 0.1 , max_iter = 5 ,
803
- solver = solver , verbose = True
804
- )
805
- fgw_valb = ot .gromov .entropic_fused_gromov_wasserstein2 (
806
- Mb , C1b , C2b , pb , qb , 'square_loss' , epsilon = 0.1 , max_iter = 5 ,
807
- solver = solver , verbose = True
808
- )
981
+ for solver in ['PGD' , 'PPA' , 'BAPG' ]:
982
+ if solver == 'BAPG' :
983
+ Gb = ot .gromov .entropic_BAPG_fused_gromov_wasserstein (
984
+ Mb , C1b , C2b , pb , qb , max_iter = 2 )
985
+ fgw_valb = ot .gromov .entropic_BAPG_fused_gromov_wasserstein2 (
986
+ Mb , C1b , C2b , pb , qb , max_iter = 2 )
987
+
988
+ else :
989
+ Gb = ot .gromov .entropic_fused_gromov_wasserstein (
990
+ Mb , C1b , C2b , pb , qb , max_iter = 2 , solver = solver )
991
+ fgw_valb = ot .gromov .entropic_fused_gromov_wasserstein2 (
992
+ Mb , C1b , C2b , pb , qb , max_iter = 2 , solver = solver )
809
993
810
994
nx .assert_same_dtype_device (C1b , Gb )
811
995
nx .assert_same_dtype_device (C1b , fgw_valb )
0 commit comments