diff --git a/RELEASES.md b/RELEASES.md index cf95e489a..1ec999811 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,7 @@ #### Closed issues - Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648) +- Fixed `ot.emd_1d` and `ot.emd2_1d` incorrectly allowing any metric (PR #670, Issue #669) ## 0.9.4 *June 2024* diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index e5cec89c6..53df54fc3 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -141,10 +141,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, v : (nt,) ndarray, float64 Target dirac locations (on the real line) metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. + Metric to be used. Must be one of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. p: float, optional (default=1.0) The p-norm to apply for if metric='minkowski' @@ -182,8 +180,9 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, elif metric == 'minkowski': m_ij = math.pow(math.fabs(u[i] - v[j]), p) else: - m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)), - metric=metric)[0, 0] + raise ValueError("Solver for EMD in 1d only supports metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`") if w_i < w_j or j == m - 1: cost += m_ij * w_i G[cur_idx] = w_i diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index d9395c8d4..6d97303e2 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -152,7 +152,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - x_a and x_b are the samples - a and b are the sample weights - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. 
+ This implementation only supports metrics + of the form :math:`d(x, y) = |x - y|^p`. Uses the algorithm detailed in [1]_ @@ -167,9 +168,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, b : (nt,) ndarray, float64, optional Target histogram (default is uniform weight) metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used. + Metric to be used. Must be one of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. p: float, optional (default=1.0) The p-norm to apply for if metric='minkowski' dense: boolean, optional (default=True) @@ -234,6 +234,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, "emd_1d should only be used with monodimensional data" assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \ "emd_1d should only be used with monodimensional data" + if metric not in ['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']: + raise ValueError( + "Solver for EMD in 1d only supports metrics " + + "from the following list: " + + "`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`" + ) # if empty array given then use uniform distributions if a is None or a.ndim == 0 or len(a) == 0: @@ -300,7 +306,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, - x_a and x_b are the samples - a and b are the sample weights - When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`. + This implementation only supports metrics + of the form :math:`d(x, y) = |x - y|^p`. 
Uses the algorithm detailed in [1]_ @@ -315,10 +322,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, b : (nt,) ndarray, float64, optional Target histogram (default is uniform weight) metric: str, optional (default='sqeuclidean') - Metric to be used. Only strings listed in :func:`ot.dist` are accepted. - Due to implementation details, this function runs faster when - `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics - are used. + Metric to be used. Must be one of the strings + `'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`. p: float, optional (default=1.0) The p-norm to apply for if metric='minkowski' dense: boolean, optional (default=True) diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 131757610..8fec3e346 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -50,6 +50,12 @@ def test_emd_1d_emd2_1d_with_weights(): np.testing.assert_allclose(w_u, G.sum(1)) np.testing.assert_allclose(w_v, G.sum(0)) + # check that an error is raised if the metric is not a Minkowski one + np.testing.assert_raises(ValueError, ot.emd_1d, + u, v, w_u, w_v, metric='cosine') + np.testing.assert_raises(ValueError, ot.emd2_1d, + u, v, w_u, w_v, metric='cosine') + def test_wasserstein_1d(nx): rng = np.random.RandomState(0)