Skip to content

Commit 0b59002

Browse files
authored
[MRG] Only Minkowski metrics can be used for emd_1d (#670)
* only Minkowski metrics can be used for emd_1d * flake8 * added basic test * Fixed test * Updated RELEASES.md
1 parent 36c9252 commit 0b59002

File tree

4 files changed

+26
-15
lines changed

4 files changed

+26
-15
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#### Closed issues
1616
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
17+
- Fixed `ot.emd_1d` and `ot.emd2_1d` incorrectly allowing any metric (PR #670, Issue #669)
1718

1819
## 0.9.4
1920
*June 2024*

ot/lp/emd_wrap.pyx

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
141141
v : (nt,) ndarray, float64
142142
Target dirac locations (on the real line)
143143
metric: str, optional (default='sqeuclidean')
144-
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
145-
Due to implementation details, this function runs faster when
146-
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
147-
are used.
144+
Metric to be used. Only works with either of the strings
145+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
148146
p: float, optional (default=1.0)
149147
The p-norm to apply for if metric='minkowski'
150148
@@ -182,8 +180,9 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights,
182180
elif metric == 'minkowski':
183181
m_ij = math.pow(math.fabs(u[i] - v[j]), p)
184182
else:
185-
m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
186-
metric=metric)[0, 0]
183+
raise ValueError("Solver for EMD in 1d only supports metrics " +
184+
"from the following list: " +
185+
"`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`")
187186
if w_i < w_j or j == m - 1:
188187
cost += m_ij * w_i
189188
G[cur_idx] = w_i

ot/lp/solver_1d.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
152152
- x_a and x_b are the samples
153153
- a and b are the sample weights
154154
155-
When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
155+
This implementation only supports metrics
156+
of the form :math:`d(x, y) = |x - y|^p`.
156157
157158
Uses the algorithm detailed in [1]_
158159
@@ -167,9 +168,8 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
167168
b : (nt,) ndarray, float64, optional
168169
Target histogram (default is uniform weight)
169170
metric: str, optional (default='sqeuclidean')
170-
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
171-
Due to implementation details, this function runs faster when
172-
`'sqeuclidean'`, `'cityblock'`, or `'euclidean'` metrics are used.
171+
Metric to be used. Only works with either of the strings
172+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
173173
p: float, optional (default=1.0)
174174
The p-norm to apply for if metric='minkowski'
175175
dense: boolean, optional (default=True)
@@ -234,6 +234,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
234234
"emd_1d should only be used with monodimensional data"
235235
assert (x_b.ndim == 1 or x_b.ndim == 2 and x_b.shape[1] == 1), \
236236
"emd_1d should only be used with monodimensional data"
237+
if metric not in ['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']:
238+
raise ValueError(
239+
"Solver for EMD in 1d only supports metrics " +
240+
"from the following list: " +
241+
"`['sqeuclidean', 'minkowski', 'cityblock', 'euclidean']`"
242+
)
237243

238244
# if empty array given then use uniform distributions
239245
if a is None or a.ndim == 0 or len(a) == 0:
@@ -300,7 +306,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
300306
- x_a and x_b are the samples
301307
- a and b are the sample weights
302308
303-
When 'minkowski' is used as a metric, :math:`d(x, y) = |x - y|^p`.
309+
This implementation only supports metrics
310+
of the form :math:`d(x, y) = |x - y|^p`.
304311
305312
Uses the algorithm detailed in [1]_
306313
@@ -315,10 +322,8 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
315322
b : (nt,) ndarray, float64, optional
316323
Target histogram (default is uniform weight)
317324
metric: str, optional (default='sqeuclidean')
318-
Metric to be used. Only strings listed in :func:`ot.dist` are accepted.
319-
Due to implementation details, this function runs faster when
320-
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'` metrics
321-
are used.
325+
Metric to be used. Only works with either of the strings
326+
`'sqeuclidean'`, `'minkowski'`, `'cityblock'`, or `'euclidean'`.
322327
p: float, optional (default=1.0)
323328
The p-norm to apply for if metric='minkowski'
324329
dense: boolean, optional (default=True)

test/test_1d_solver.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,12 @@ def test_emd_1d_emd2_1d_with_weights():
5050
np.testing.assert_allclose(w_u, G.sum(1))
5151
np.testing.assert_allclose(w_v, G.sum(0))
5252

53+
# check that an error is raised if the metric is not a Minkowski one
54+
np.testing.assert_raises(ValueError, ot.emd_1d,
55+
u, v, w_u, w_v, metric='cosine')
56+
np.testing.assert_raises(ValueError, ot.emd2_1d,
57+
u, v, w_u, w_v, metric='cosine')
58+
5359

5460
def test_wasserstein_1d(nx):
5561
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)