@@ -490,6 +490,43 @@ def test_barycenter(nx, method, verbose, warn):
490
490
ot .bregman .barycenter (A_nx , M_nx , reg , log = True )
491
491
492
492
493
+
494
+ @pytest .mark .parametrize ("method, verbose, warn" ,
495
+ product (["sinkhorn" , "sinkhorn_stabilized" , "sinkhorn_log" ],
496
+ [True , False ], [True , False ]))
497
+ def test_barycenter_assymetric_cost (nx , method , verbose , warn ):
498
+ n_bins = 20 # nb bins
499
+
500
+ # Gaussian distributions
501
+ A = ot .datasets .make_1D_gauss (n_bins , m = 30 , s = 10 ) # m= mean, s= std
502
+
503
+ # creating matrix A containing all distributions
504
+ A = A [:, None ]
505
+
506
+ # assymetric loss matrix + normalization
507
+ rng = np .random .RandomState (42 )
508
+ M = rng .randn (n_bins , n_bins ) ** 2
509
+ M /= M .max ()
510
+
511
+ A_nx , M_nx = nx .from_numpy (A , M )
512
+ reg = 1e-2
513
+
514
+ if nx .__name__ in ("jax" , "tf" ) and method == "sinkhorn_log" :
515
+ with pytest .raises (NotImplementedError ):
516
+ ot .bregman .barycenter (A_nx , M_nx , reg , method = method )
517
+ else :
518
+ # wasserstein
519
+ bary_wass_np = ot .bregman .barycenter (A , M , reg , method = method , verbose = verbose , warn = warn )
520
+ bary_wass , _ = ot .bregman .barycenter (A_nx , M_nx , reg , method = method , log = True )
521
+ bary_wass = nx .to_numpy (bary_wass )
522
+
523
+ np .testing .assert_allclose (1 , np .sum (bary_wass ))
524
+ np .testing .assert_allclose (bary_wass , bary_wass_np )
525
+
526
+ ot .bregman .barycenter (A_nx , M_nx , reg , log = True )
527
+
528
+
529
+
493
530
@pytest .mark .parametrize ("method, verbose, warn" ,
494
531
product (["sinkhorn" , "sinkhorn_log" ],
495
532
[True , False ], [True , False ]))
0 commit comments