Skip to content

Commit 8cc8dd2

Browse files
framunoz and rflamary authored
[FIX] Refactor the function utils.cost_normalization to work with multiple backends (#472)
* [FEAT] Add the 'median' method to the backend base class and implements this method in the Numpy, Pytorch, Jax and Cupy backends
* [TEST] Modify the 'cost_normalization' test to multiple backends
* [REFACTOR] Refactor the 'utils.cost_normalization' function for multiple backends
* [TEST] Update backend tests for method 'median'
* [DEBUG] Fix the error in the test in the 'median' method with PyTorch backend
* [TEST] Add the edge case where the 'median' method is not yet implemented in the Tensorflow backend.
* [FEAT] Implement the 'median' method in the Tensorflow backend using Numpy
* [DEBUG] For compatibility reasons, the median method in the Pytorch backend change using numpy
* [DEBUG] The 'median' method checks the Pytorch version to decide whether to use torch.quantile or numpy
* Add changes to RELEASES.md

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 03341c6 commit 8cc8dd2

File tree

5 files changed

+73
-15
lines changed

5 files changed

+73
-15
lines changed

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
77
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
88
- Add tests on GPU for master branch and approved PR (PR #473)
9+
- Add `median` method to all inherited classes of `backend.Backend` (PR #472)
910

1011
#### Closed issues
1112

@@ -16,6 +17,7 @@
1617
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)
1718
- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
1819
- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (PR #474)
20+
- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472)
1921

2022
## 0.9.0
2123

ot/backend.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,16 @@ def mean(self, a, axis=None):
574574
"""
575575
raise NotImplementedError()
576576

577+
def median(self, a, axis=None):
578+
r"""
579+
Computes the median of a tensor along given dimensions.
580+
581+
This function follows the api from :any:`numpy.median`
582+
583+
See: https://numpy.org/doc/stable/reference/generated/numpy.median.html
584+
"""
585+
raise NotImplementedError()
586+
577587
def std(self, a, axis=None):
578588
r"""
579589
Computes the standard deviation of a tensor along given dimensions.
@@ -1123,6 +1133,9 @@ def argmin(self, a, axis=None):
11231133
def mean(self, a, axis=None):
11241134
return np.mean(a, axis=axis)
11251135

1136+
def median(self, a, axis=None):
1137+
return np.median(a, axis=axis)
1138+
11261139
def std(self, a, axis=None):
11271140
return np.std(a, axis=axis)
11281141

@@ -1482,6 +1495,9 @@ def argmin(self, a, axis=None):
14821495
def mean(self, a, axis=None):
14831496
return jnp.mean(a, axis=axis)
14841497

1498+
def median(self, a, axis=None):
1499+
return jnp.median(a, axis=axis)
1500+
14851501
def std(self, a, axis=None):
14861502
return jnp.std(a, axis=axis)
14871503

@@ -1899,6 +1915,22 @@ def mean(self, a, axis=None):
18991915
else:
19001916
return torch.mean(a)
19011917

1918+
def median(self, a, axis=None):
1919+
from packaging import version
1920+
# Since version 1.11.0, interpolation is available
1921+
if version.parse(torch.__version__) >= version.parse("1.11.0"):
1922+
if axis is not None:
1923+
return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis)
1924+
else:
1925+
return torch.quantile(a, 0.5, interpolation="midpoint")
1926+
1927+
# Else, use numpy
1928+
warnings.warn("The median is being computed using numpy and the array has been detached "
1929+
"in the Pytorch backend.")
1930+
a_ = self.to_numpy(a)
1931+
a_median = np.median(a_, axis=axis)
1932+
return self.from_numpy(a_median, type_as=a)
1933+
19021934
def std(self, a, axis=None):
19031935
if axis is not None:
19041936
return torch.std(a, dim=axis, unbiased=False)
@@ -2289,6 +2321,9 @@ def argmin(self, a, axis=None):
22892321
def mean(self, a, axis=None):
22902322
return cp.mean(a, axis=axis)
22912323

2324+
def median(self, a, axis=None):
2325+
return cp.median(a, axis=axis)
2326+
22922327
def std(self, a, axis=None):
22932328
return cp.std(a, axis=axis)
22942329

@@ -2678,6 +2713,13 @@ def argmin(self, a, axis=None):
26782713
def mean(self, a, axis=None):
26792714
return tnp.mean(a, axis=axis)
26802715

2716+
def median(self, a, axis=None):
2717+
warnings.warn("The median is being computed using numpy and the array has been detached "
2718+
"in the Tensorflow backend.")
2719+
a_ = self.to_numpy(a)
2720+
a_median = np.median(a_, axis=axis)
2721+
return self.from_numpy(a_median, type_as=a)
2722+
26812723
def std(self, a, axis=None):
26822724
return tnp.std(a, axis=axis)
26832725

ot/utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,16 +359,18 @@ def cost_normalization(C, norm=None):
359359
The input cost matrix normalized according to given norm.
360360
"""
361361

362+
nx = get_backend(C)
363+
362364
if norm is None:
363365
pass
364366
elif norm == "median":
365-
C /= float(np.median(C))
367+
C /= float(nx.median(C))
366368
elif norm == "max":
367-
C /= float(np.max(C))
369+
C /= float(nx.max(C))
368370
elif norm == "log":
369-
C = np.log(1 + C)
371+
C = nx.log(1 + C)
370372
elif norm == "loglog":
371-
C = np.log1p(np.log1p(C))
373+
C = nx.log(1 + nx.log(1 + C))
372374
else:
373375
raise ValueError('Norm %s is not a valid option.\n'
374376
'Valid options are:\n'

test/test_backend.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ def test_empty_backend():
221221
nx.argmin(M)
222222
with pytest.raises(NotImplementedError):
223223
nx.mean(M)
224+
with pytest.raises(NotImplementedError):
225+
nx.median(M)
224226
with pytest.raises(NotImplementedError):
225227
nx.std(M)
226228
with pytest.raises(NotImplementedError):
@@ -519,6 +521,10 @@ def test_func_backends(nx):
519521
lst_b.append(nx.to_numpy(A))
520522
lst_name.append('mean')
521523

524+
A = nx.median(Mb)
525+
lst_b.append(nx.to_numpy(A))
526+
lst_name.append('median')
527+
522528
A = nx.std(Mb)
523529
lst_b.append(nx.to_numpy(A))
524530
lst_name.append('std')

test/test_utils.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -270,25 +270,31 @@ def test_clean_zeros():
270270
assert len(b) == n - nz2
271271

272272

273-
def test_cost_normalization():
273+
def test_cost_normalization(nx):
274274

275275
C = np.random.rand(10, 10)
276+
C1 = nx.from_numpy(C)
276277

277278
# does nothing
278-
M0 = ot.utils.cost_normalization(C)
279-
np.testing.assert_allclose(C, M0)
279+
M0 = ot.utils.cost_normalization(C1)
280+
M1 = nx.to_numpy(M0)
281+
np.testing.assert_allclose(C, M1)
280282

281-
M = ot.utils.cost_normalization(C, 'median')
282-
np.testing.assert_allclose(np.median(M), 1)
283+
M = ot.utils.cost_normalization(C1, 'median')
284+
M1 = nx.to_numpy(M)
285+
np.testing.assert_allclose(np.median(M1), 1)
283286

284-
M = ot.utils.cost_normalization(C, 'max')
285-
np.testing.assert_allclose(M.max(), 1)
287+
M = ot.utils.cost_normalization(C1, 'max')
288+
M1 = nx.to_numpy(M)
289+
np.testing.assert_allclose(M1.max(), 1)
286290

287-
M = ot.utils.cost_normalization(C, 'log')
288-
np.testing.assert_allclose(M.max(), np.log(1 + C).max())
291+
M = ot.utils.cost_normalization(C1, 'log')
292+
M1 = nx.to_numpy(M)
293+
np.testing.assert_allclose(M1.max(), np.log(1 + C).max())
289294

290-
M = ot.utils.cost_normalization(C, 'loglog')
291-
np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max())
295+
M = ot.utils.cost_normalization(C1, 'loglog')
296+
M1 = nx.to_numpy(M)
297+
np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max())
292298

293299

294300
def test_check_params():

0 commit comments

Comments (0)