Closed
Description
Hi,
First of all, thank you for putting this library together!
I believe there is a small mistake in the implementation of `emd_1d` at the line linked below: the weights should be passed in reordered by the same permutation used to sort the sample positions, rather than in their original order.
https://github.com/rflamary/POT/blob/fa06bb377d083c61f1ac0b067aeeab0fca2b5e7b/ot/lp/__init__.py#L659
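To make the point concrete, here is a minimal sketch of a 1D solver that applies the fix, assuming a north-west-corner sweep over sorted supports. `emd2_1d_sketch` is a hypothetical helper written for illustration, not POT's API or its actual code. The key step is that `argsort` on the positions is applied to the weights as well, so each weight stays attached to its own sample.

```python
import numpy as np


def emd2_1d_sketch(x_a, x_b, a, b, p=1.0):
    """Hypothetical helper (not POT's API): return sum_ij G_ij * |x_a[i] - x_b[j]|**p
    for the monotone (optimal, for p >= 1) coupling G between two weighted
    1D empirical distributions with equal total mass."""
    x_a = np.asarray(x_a, dtype=float)
    x_b = np.asarray(x_b, dtype=float)
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)

    # Sort the positions AND reorder the weights with the same permutation.
    # Pairing sorted positions with the original, unsorted weights is the bug.
    perm_a = np.argsort(x_a)
    perm_b = np.argsort(x_b)
    x_a, a = x_a[perm_a], a[perm_a]
    x_b, b = x_b[perm_b], b[perm_b]

    # North-west-corner sweep: on sorted supports, the monotone coupling is
    # optimal for the convex cost |x - y|**p with p >= 1.
    n, m = len(a), len(b)
    i = j = 0
    w_i, w_j = a[0], b[0]  # mass left at source x_a[i] / capacity left at target x_b[j]
    cost = 0.0
    while True:
        c_ij = abs(x_a[i] - x_b[j]) ** p
        if w_i < w_j or j == m - 1:
            # Source i is emptied first (or j is the last target, which also
            # absorbs floating-point rounding): ship all of w_i to target j.
            cost += c_ij * w_i
            i += 1
            if i == n:
                break
            w_j -= w_i
            w_i = a[i]
        else:
            # Target j is filled first: ship w_j and keep the rest of w_i.
            # j < m - 1 here, so j + 1 is always a valid index.
            cost += c_ij * w_j
            j += 1
            w_i -= w_j
            w_j = b[j]
    return cost
```

For example, `emd2_1d_sketch([3.0, 0.0], [0.0, 3.0], [0.25, 0.75], [0.75, 0.25])` returns `0.0`, since both inputs describe the same distribution. Pairing the sorted positions with the unsorted weights would instead compare {0: 0.25, 3: 0.75} against {0: 0.75, 3: 0.25} and report a nonzero distance.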
I've put a reproduction of the problem below.
Thanks,
Adrien
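```python
import numpy as np
import ot
import scipy.stats as st

np.random.seed(0)  # fixed seed so the run is reproducible

n = 100
# Two weighted empirical distributions on the real line
x = np.random.normal(0., 1., n)
y = np.random.normal(0., 1., n)
w_x = np.random.uniform(0., 1., n)
w_x /= w_x.sum()
w_y = np.random.uniform(0., 1., n)
w_y /= w_y.sum()

# Shuffling the positions without shuffling the weights changes the distributions,
# so the two distances should differ. With the bug, emd2_1d sorts the positions but
# keeps the weights in their original order, so both calls return the same value.
random_index_x = np.random.choice(n, n, replace=False)
random_index_y = np.random.choice(n, n, replace=False)
assert abs(ot.emd2_1d(x[random_index_x], y[random_index_y], w_x, w_y) - ot.emd2_1d(x, y, w_x, w_y)) > 1e-6, 'This should not have raised'

# W_1 from emd2_1d should match scipy's 1D Wasserstein distance...
wasserstein_1 = ot.emd2_1d(x, y, w_x, w_y, metric='minkowski', p=1)
assert abs(wasserstein_1 - st.wasserstein_distance(x, y, w_x, w_y)) < 1e-6, 'This should be true'

# ...and the general LP solver ot.emd2 on the full |x_i - y_j| cost matrix
M = ot.utils.cdist(x[:, None], y[:, None], metric='minkowski', p=1)
wasserstein_1_bis = ot.emd2(w_x, w_y, M)
assert abs(wasserstein_1 - wasserstein_1_bis) < 1e-6, 'This should be true'
```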