Skip to content

Commit fe49a66

Browse files
author
Hicham Janati
committed
add test for assymetric cost barycenters
1 parent a07687c commit fe49a66

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

test/test_bregman.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,43 @@ def test_barycenter(nx, method, verbose, warn):
490490
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
491491

492492

493+
494+
@pytest.mark.parametrize("method, verbose, warn",
495+
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
496+
[True, False], [True, False]))
497+
def test_barycenter_assymetric_cost(nx, method, verbose, warn):
498+
n_bins = 20 # nb bins
499+
500+
# Gaussian distributions
501+
A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
502+
503+
# creating matrix A containing all distributions
504+
A = A[:, None]
505+
506+
# assymetric loss matrix + normalization
507+
rng = np.random.RandomState(42)
508+
M = rng.randn(n_bins, n_bins) ** 2
509+
M /= M.max()
510+
511+
A_nx, M_nx = nx.from_numpy(A, M)
512+
reg = 1e-2
513+
514+
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log":
515+
with pytest.raises(NotImplementedError):
516+
ot.bregman.barycenter(A_nx, M_nx, reg, method=method)
517+
else:
518+
# wasserstein
519+
bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn)
520+
bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True)
521+
bary_wass = nx.to_numpy(bary_wass)
522+
523+
np.testing.assert_allclose(1, np.sum(bary_wass))
524+
np.testing.assert_allclose(bary_wass, bary_wass_np)
525+
526+
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
527+
528+
529+
493530
@pytest.mark.parametrize("method, verbose, warn",
494531
product(["sinkhorn", "sinkhorn_log"],
495532
[True, False], [True, False]))

0 commit comments

Comments
 (0)