From df08b967aefcab3938a8f6829ada23a5ae67f4c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Wed, 3 May 2023 18:50:34 -0400 Subject: [PATCH 01/10] [FEAT] Add the 'median' method to the backend base class and implements this method in the Numpy, Pytorch, Jax and Cupy backends --- ot/backend.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/ot/backend.py b/ot/backend.py index a82c4486a..6f5448fa1 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -574,6 +574,16 @@ def mean(self, a, axis=None): """ raise NotImplementedError() + def median(self, a, axis=None): + r""" + Computes the median of a tensor along given dimensions. + + This function follows the api from :any:`numpy.median` + + See: https://numpy.org/doc/stable/reference/generated/numpy.median.html + """ + raise NotImplementedError() + def std(self, a, axis=None): r""" Computes the standard deviation of a tensor along given dimensions. @@ -1115,6 +1125,9 @@ def argmin(self, a, axis=None): def mean(self, a, axis=None): return np.mean(a, axis=axis) + def median(self, a, axis=None): + return np.median(a, axis=axis) + def std(self, a, axis=None): return np.std(a, axis=axis) @@ -1470,6 +1483,9 @@ def argmin(self, a, axis=None): def mean(self, a, axis=None): return jnp.mean(a, axis=axis) + def median(self, a, axis=None): + return jnp.median(a, axis=axis) + def std(self, a, axis=None): return jnp.std(a, axis=axis) @@ -1884,6 +1900,12 @@ def mean(self, a, axis=None): else: return torch.mean(a) + def median(self, a, axis=None): + if axis is not None: + return torch.median(a, dim=axis) + else: + return torch.median(a) + def std(self, a, axis=None): if axis is not None: return torch.std(a, dim=axis, unbiased=False) @@ -2271,6 +2293,9 @@ def argmin(self, a, axis=None): def mean(self, a, axis=None): return cp.mean(a, axis=axis) + def median(self, a, axis=None): + return cp.median(a, axis=axis) + def std(self, a, axis=None): return cp.std(a, axis=axis) From 897904cc0fe7ddff7b00e9210f911856cc05c12d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Wed, 3 May 2023 18:51:53 -0400 Subject: [PATCH 02/10] [TEST] Modify the 'cost_normalization' test to multiple backends --- test/test_utils.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 658214d21..87f4dc48d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -270,25 +270,31 @@ def test_clean_zeros(): assert len(b) == n - nz2 -def test_cost_normalization(): +def test_cost_normalization(nx): C = np.random.rand(10, 10) + C1 = nx.from_numpy(C) # does nothing - M0 = ot.utils.cost_normalization(C) - np.testing.assert_allclose(C, M0) + M0 = ot.utils.cost_normalization(C1) + M1 = nx.to_numpy(M0) + np.testing.assert_allclose(C, M1) - M = ot.utils.cost_normalization(C, 'median') - np.testing.assert_allclose(np.median(M), 1) + M = ot.utils.cost_normalization(C1, 'median') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(np.median(M1), 1) - M = ot.utils.cost_normalization(C, 'max') - np.testing.assert_allclose(M.max(), 1) + M = ot.utils.cost_normalization(C1, 'max') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), 1) - M = ot.utils.cost_normalization(C, 'log') - np.testing.assert_allclose(M.max(), np.log(1 + C).max()) + M = ot.utils.cost_normalization(C1, 'log') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), np.log(1 + C).max()) - M = ot.utils.cost_normalization(C, 'loglog') - np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max()) + M = ot.utils.cost_normalization(C1, 'loglog') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max()) def test_check_params(): From 2cff3a3bca6b0c046bd3e60d6c80f72c4ff196d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Wed, 3 May 2023 18:53:17 -0400 Subject: [PATCH 03/10] [REFACTOR] Refactor the 'utils.cost_normalization' function for multiple backends --- ot/utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ot/utils.py b/ot/utils.py index 3343028ec..091b26813 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -359,16 +359,18 @@ def cost_normalization(C, norm=None): The input cost matrix normalized according to given norm. """ + nx = get_backend(C) + if norm is None: pass elif norm == "median": - C /= float(np.median(C)) + C /= float(nx.median(C)) elif norm == "max": - C /= float(np.max(C)) + C /= float(nx.max(C)) elif norm == "log": - C = np.log(1 + C) + C = nx.log(1 + C) elif norm == "loglog": - C = np.log1p(np.log1p(C)) + C = nx.log(1 + nx.log(1 + C)) else: raise ValueError('Norm %s is not a valid option.\n' 'Valid options are:\n' From 61346e7404615aef340157362c0a208e722edd13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Wed, 3 May 2023 18:55:03 -0400 Subject: [PATCH 04/10] [TEST] Update backend tests for method 'median' --- test/test_backend.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/test_backend.py b/test/test_backend.py index 5351e5283..3807b9fb2 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -221,6 +221,8 @@ def test_empty_backend(): nx.argmin(M) with pytest.raises(NotImplementedError): nx.mean(M) + with pytest.raises(NotImplementedError): + nx.median(M) with pytest.raises(NotImplementedError): nx.std(M) with pytest.raises(NotImplementedError): @@ -511,6 +513,10 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('mean') + A = nx.median(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('median') + A = nx.std(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append('std') From caa3ea37a08913b61044a4a4291086603aaf310e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Thu, 4 May 2023 13:32:59 -0400 Subject: [PATCH 05/10] [DEBUG] Fix the error in the test in the 'median' method with PyTorch backend --- ot/backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 6f5448fa1..c2a2ece8a 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1902,9 +1902,9 @@ def mean(self, a, axis=None): def median(self, a, axis=None): if axis is not None: - return torch.median(a, dim=axis) + return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis) else: - return torch.median(a) + return torch.quantile(a, 0.5, interpolation="midpoint") def std(self, a, axis=None): if axis is not None: From 7501a4360bdc920d59ae1cf2620b0c6034b69e9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Thu, 4 May 2023 14:10:48 -0400 Subject: [PATCH 06/10] [TEST] Add the edge case where the 'median' method is not yet implemented in the Tensorflow backend. --- ot/backend.py | 5 +++++ test/test_backend.py | 10 +++++++--- test/test_utils.py | 11 ++++++++--- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index c2a2ece8a..7f9b59c13 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -2682,6 +2682,11 @@ def argmin(self, a, axis=None): def mean(self, a, axis=None): return tnp.mean(a, axis=axis) + # This could be a tentative implementation, in case of installing tensorflow_probability + # def median(self, a, axis=None): + # import tensorflow_probability as tfp + # return tfp.stats.percentile(a, 50., axis=axis, interpolation="midpoint") + def std(self, a, axis=None): return tnp.std(a, axis=axis) diff --git a/test/test_backend.py b/test/test_backend.py index 3807b9fb2..6d78dd9b5 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -513,9 +513,13 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('mean') - A = nx.median(Mb) - lst_b.append(nx.to_numpy(A)) - lst_name.append('median') + if not nx.__name__ == 'tf': + A = nx.median(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('median') + else: + with pytest.raises(NotImplementedError): + nx.median(Mb) A = nx.std(Mb) lst_b.append(nx.to_numpy(A)) diff --git a/test/test_utils.py b/test/test_utils.py index 87f4dc48d..cb22dc84c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -280,9 +280,14 @@ def test_cost_normalization(nx): M1 = nx.to_numpy(M0) np.testing.assert_allclose(C, M1) - M = ot.utils.cost_normalization(C1, 'median') - M1 = nx.to_numpy(M) - np.testing.assert_allclose(np.median(M1), 1) + # This function still not work with Tensorflow backend + if not nx.__name__ == 'tf': + M = ot.utils.cost_normalization(C1, 'median') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(np.median(M1), 1) + else: + with pytest.raises(NotImplementedError): + ot.utils.cost_normalization(C1, 'median') M = ot.utils.cost_normalization(C1, 'max') M1 = nx.to_numpy(M) From 95fde336d34e822b4cda11cbfec845fdcab1facb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 5 May 2023 12:13:59 -0400 Subject: [PATCH 07/10] [FEAT] Implement the 'median' method in the Tensorflow backend using Numpy --- ot/backend.py | 10 ++++++---- test/test_backend.py | 10 +++------- test/test_utils.py | 11 +++-------- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index cc56b7a1c..c621cdf1f 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -2703,10 +2703,12 @@ def argmin(self, a, axis=None): def mean(self, a, axis=None): return tnp.mean(a, axis=axis) - # This could be a tentative implementation, in case of installing tensorflow_probability - # def median(self, a, axis=None): - # import tensorflow_probability as tfp - # return tfp.stats.percentile(a, 50., axis=axis, interpolation="midpoint") + def median(self, a, axis=None): + warnings.warn("The median it is being computed using numpy and the array is detached in " + "the Tensorflow backend.") + a_ = self.to_numpy(a) + a_median = np.median(a_, axis=axis) + return self.from_numpy(a_median, type_as=a) def std(self, a, axis=None): return tnp.std(a, axis=axis) diff --git a/test/test_backend.py b/test/test_backend.py index 89ef38fee..799ac54d3 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -521,13 +521,9 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append('mean') - if not nx.__name__ == 'tf': - A = nx.median(Mb) - lst_b.append(nx.to_numpy(A)) - lst_name.append('median') - else: - with pytest.raises(NotImplementedError): - nx.median(Mb) + A = nx.median(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append('median') A = nx.std(Mb) lst_b.append(nx.to_numpy(A)) diff --git a/test/test_utils.py b/test/test_utils.py index cb22dc84c..87f4dc48d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -280,14 +280,9 @@ def test_cost_normalization(nx): M1 = nx.to_numpy(M0) np.testing.assert_allclose(C, M1) - # This function still not work with Tensorflow backend - if not nx.__name__ == 'tf': - M = ot.utils.cost_normalization(C1, 'median') - M1 = nx.to_numpy(M) - np.testing.assert_allclose(np.median(M1), 1) - else: - with pytest.raises(NotImplementedError): - ot.utils.cost_normalization(C1, 'median') + M = ot.utils.cost_normalization(C1, 'median') + M1 = nx.to_numpy(M) + np.testing.assert_allclose(np.median(M1), 1) M = ot.utils.cost_normalization(C1, 'max') M1 = nx.to_numpy(M) From aa4ad9bb95025e1ca32394ff1aeaa61be31dadec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Fri, 5 May 2023 23:25:41 -0400 Subject: [PATCH 08/10] [DEBUG] For compatibility reasons, the median method in the Pytorch backend change using numpy --- ot/backend.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index c621cdf1f..dc2880e1c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1916,10 +1916,15 @@ def mean(self, a, axis=None): return torch.mean(a) def median(self, a, axis=None): - if axis is not None: - return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis) - else: - return torch.quantile(a, 0.5, interpolation="midpoint") + # if axis is not None: + # return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis) + # else: + # return torch.quantile(a, 0.5, interpolation="midpoint") + warnings.warn("The median is being computed using numpy and the array has been detached " + "in the Pytorch backend.") + a_ = self.to_numpy(a) + a_median = np.median(a_, axis=axis) + return self.from_numpy(a_median, type_as=a) def std(self, a, axis=None): if axis is not None: @@ -2704,8 +2709,8 @@ def mean(self, a, axis=None): return tnp.mean(a, axis=axis) def median(self, a, axis=None): - warnings.warn("The median it is being computed using numpy and the array is detached in " - "the Tensorflow backend.") + warnings.warn("The median is being computed using numpy and the array has been detached " + "in the Tensorflow backend.") a_ = self.to_numpy(a) a_median = np.median(a_, axis=axis) return self.from_numpy(a_median, type_as=a) From 3c1186313401210d0e8ffaccdc92c4ed32dec4c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Sat, 6 May 2023 20:49:23 -0400 Subject: [PATCH 09/10] [DEBUG] The 'median' method checks the Pytorch version to decide whether to use torch.quantile or numpy --- ot/backend.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index dc2880e1c..9aa14e6b2 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1916,10 +1916,15 @@ def mean(self, a, axis=None): return torch.mean(a) def median(self, a, axis=None): - # if axis is not None: - # return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis) - # else: - # return torch.quantile(a, 0.5, interpolation="midpoint") + from packaging import version + # Since version 1.11.0, interpolation is available + if version.parse(torch.__version__) >= version.parse("1.11.0"): + if axis is not None: + return torch.quantile(a, 0.5, interpolation="midpoint", dim=axis) + else: + return torch.quantile(a, 0.5, interpolation="midpoint") + + # Else, use numpy warnings.warn("The median is being computed using numpy and the array has been detached " "in the Pytorch backend.") a_ = self.to_numpy(a) From 7108ef310d709e4d945f87f4a2a794cae0a7e04f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Mu=C3=B1oz?= Date: Tue, 9 May 2023 11:08:18 -0400 Subject: [PATCH 10/10] Add changes to RELEASES.md --- RELEASES.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index 02fddad91..2e79d7562 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,7 @@ - Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463) - Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459) - Add tests on GPU for master branch and approved PR (PR #473) +- Add `median` method to all inherited classes of `backend.Backend` (PR #472) #### Closed issues @@ -15,6 +16,7 @@ - Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466) - Faster Bures-Wasserstein distance with NumPy backend (PR #468) - Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471) +- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472) ## 0.9.0