From f309f8096fd392ed33a485136c03b92b9c8e470d Mon Sep 17 00:00:00 2001 From: Clement Date: Fri, 3 Feb 2023 17:57:42 +0100 Subject: [PATCH 01/16] W circle + SSW --- CONTRIBUTORS.md | 1 + README.md | 9 +- ot/__init__.py | 9 +- ot/backend.py | 184 +++++++++++++++++-- ot/lp/__init__.py | 5 +- ot/lp/solver_1d.py | 447 ++++++++++++++++++++++++++++++++++++++++++++- ot/sliced.py | 93 +++++++++- 7 files changed, 727 insertions(+), 21 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 67d8337ff..ccd0fd1a1 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,6 +41,7 @@ The contributors to this library are: * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) +* [Clément Bonet](https://clbonet.github.io) (Spherical Sliced-Wasserstein) ## Acknowledgments diff --git a/README.md b/README.md index 7c9475b80..1d24bc358 100644 --- a/README.md +++ b/README.md @@ -292,4 +292,11 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021. -[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. \ No newline at end of file +[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + +[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + +[45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2022). [Spherical sliced-wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. \ No newline at end of file diff --git a/ot/__init__.py b/ot/__init__.py index 0b55e0c56..c109d1425 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -38,12 +38,12 @@ from . import gaussian # OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d +from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, w1_circle, w_circle from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance +from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance, sliced_wasserstein_sphere from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport @@ -60,8 +60,9 @@ 'emd2_1d', 'wasserstein_1d', 'backend', 'gaussian', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', - 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers'] + 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', + 'binary_search_circle', 'w1_circle', 'w_circle'] diff --git a/ot/backend.py b/ot/backend.py index 337e040d3..cf332567d 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -534,7 +534,7 @@ def concatenate(self, arrays, axis=0): """ raise NotImplementedError() - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value): r""" Pads a tensor. @@ -895,6 +895,63 @@ def is_floating_point(self, a): """ raise NotImplementedError() + def tile(self, a, reps): + r""" + Construct an array by repeating a the number of times given by reps + + See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html + """ + raise NotImplementedError() + + + def floor(self, a): + r""" + Return the floor of the input element-wise + + See: https://numpy.org/doc/stable/reference/generated/numpy.floor.html + """ + raise NotImplementedError() + + def prod(self, a, axis): + r""" + Return the product of all elements. + + See: https://pytorch.org/docs/stable/generated/torch.prod.html + """ + raise NotImplementedError() + + def sort2(self, a, axis=None): + r""" + Return the sorted array and the indices to sort the array + + See: https://pytorch.org/docs/stable/generated/torch.sort.html + """ + raise NotImplementedError() + + def qr(self, a): + r""" + Return the QR factorization + + See: https://pytorch.org/docs/stable/generated/torch.linalg.qr.html + """ + raise NotImplementedError() + + def atan2(self, a, b): + r""" + Element wise arctangent + + See: https://numpy.org/doc/stable/reference/generated/numpy.arctan2.html + """ + raise NotImplementedError() + + def transpose(self, a, dim0, dim1): + r""" + Returns a tensor that is a transposed version of a. The given dimensions dim0 and dim1 are swapped. + + See: https://pytorch.org/docs/stable/generated/torch.transpose.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1039,8 +1096,8 @@ def take_along_axis(self, arr, indices, axis): def concatenate(self, arrays, axis=0): return np.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return np.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return np.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return np.argmax(a, axis=axis) @@ -1185,6 +1242,27 @@ def array_equal(self, a, b): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return np.tile(a, reps) + + def floor(self, a): + return np.floor(a) + + def prod(self, a, axis): + return np.prod(a, axis=axis) + + def sort2(self, a, axis): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return np.linalg.qr(a) + + def atan2(self, a, b): + return np.arctan2(a, b) + + def transpose(self, a, dim0, dim1): + return np.transpose(a, axes=[0,dim1,dim0]) + class JaxBackend(Backend): """ @@ -1351,8 +1429,8 @@ def take_along_axis(self, arr, indices, axis): def concatenate(self, arrays, axis=0): return jnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return jnp.pad(a, pad_width) + def zero_pad(self, a, value=0): + return jnp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return jnp.argmax(a, axis=axis) @@ -1511,6 +1589,27 @@ def array_equal(self, a, b): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return jnp.numpy.tile(a, reps) + + def floor(self, a): + return jnp.numpy.floor(a) + + def prod(self, a, axis): + return jnp.numpy.prod(a, axis=axis) + + def sort2(self, a, axis): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return jnp.numpy.linalg.qr(a) + + def atan2(self, a, b): + return jnp.numpy.arctan2(a, b) + + def transpose(self, a, dim0, dim1): + return jnp.numpy.transpose(a, axes=[0,dim1,dim0]) + class TorchBackend(Backend): """ @@ -1729,13 +1828,13 @@ def take_along_axis(self, arr, indices, axis): def concatenate(self, arrays, axis=0): return torch.cat(arrays, dim=axis) - def zero_pad(self, a, pad_width): + def zero_pad(self, a, pad_width, value=0): from torch.nn.functional import pad # pad_width is an array of ndim tuples indicating how many 0 before and after # we need to add. We first need to make it compliant with torch syntax, that # starts with the last dim, then second last, etc. how_pad = tuple(element for tupl in pad_width[::-1] for element in tupl) - return pad(a, how_pad) + return pad(a, how_pad, value=value) def argmax(self, a, axis=None): return torch.argmax(a, dim=axis) @@ -1934,6 +2033,27 @@ def array_equal(self, a, b): def is_floating_point(self, a): return a.dtype.is_floating_point + def tile(self, a, reps): + return a.repeat(reps) + + def floor(self, a): + return torch.floor(a) + + def prod(self, a, axis): + return torch.prod(a, dim=axis) + + def sort2(self, a, axis): + return torch.sort(a, axis) + + def qr(self, a): + return torch.linalg.qr(a) + + def atan2(self, a, b): + return torch.atan2(a, b) + + def transpose(self, a, dim0, dim1): + return torch.transpose(a, dim0, dim1) + class CupyBackend(Backend): # pragma: no cover """ @@ -2096,8 +2216,8 @@ def take_along_axis(self, arr, indices, axis): def concatenate(self, arrays, axis=0): return cp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return cp.pad(a, pad_width) + def zero_pad(self, a, pad_width, value=0): + return cp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): return cp.argmax(a, axis=axis) @@ -2284,6 +2404,27 @@ def array_equal(self, a, b): def is_floating_point(self, a): return a.dtype.kind == "f" + def tile(self, a, reps): + return cp.tile(a, reps) + + def floor(self, a): + return cp.floor(a) + + def prod(self, a, axis): + return cp.prod(a, axis=axis) + + def sort2(self, a, axis): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return cp.linalg.qr(a) + + def atan2(self, a, b): + return cp.arctan2(a, b) + + def transpose(self, a, dim0, dim1): + return cp.transpose(a, axes=[0, dim1, dim0]) + class TensorflowBackend(Backend): @@ -2454,8 +2595,8 @@ def take_along_axis(self, arr, indices, axis): def concatenate(self, arrays, axis=0): return tnp.concatenate(arrays, axis) - def zero_pad(self, a, pad_width): - return tnp.pad(a, pad_width, mode="constant") + def zero_pad(self, a, pad_width, value=0): + return tnp.pad(a, pad_width, mode="constant", constant_values=value) def argmax(self, a, axis=None): return tnp.argmax(a, axis=axis) @@ -2646,3 +2787,24 @@ def array_equal(self, a, b): def is_floating_point(self, a): return a.dtype.is_floating + + def tile(self, a, reps): + return tf.tile(a, reps) + + def floor(self, a): + return tf.floor(a) + + def prod(self, a, axis): + return tf.experimental.numpy.prod(a, axis=axis) + + def sort2(self, a, axis): + return self.sort(a, axis), self.argsort(a, axis) + + def qr(self, a): + return tf.linalg.qr(a) + + def atan2(self, a, b): + return tf.math.atan2(a, b) + + def transpose(self, a, dim0, dim1): + return tf.transpose(a, perm=[0, dim1, dim0]) \ No newline at end of file diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 17411d02b..b3a0f803c 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -20,14 +20,15 @@ # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import emd_1d, emd2_1d, wasserstein_1d +from .solver_1d import emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, w1_circle, w_circle from ..utils import dist, list_to_array from ..utils import parmap from ..backend import get_backend __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', - 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter'] + 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', + 'binary_search_circle', 'w1_circle', 'w_circle'] def check_number_threads(numThreads): diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 43763a9bd..cf0d19868 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -53,7 +53,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ distributions .. math: - OT_{loss} = \int_0^1 |cdf_u^{-1}(q) cdf_v^{-1}(q)|^p dq + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq It is formally the p-Wasserstein distance raised to the power p. We do so in a vectorized way by first building the individual quantile functions then integrating them. @@ -365,3 +365,448 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, log_emd = {'G': G} return cost, log_emd return cost + + + +def roll_cols(M, shifts): + r""" + Utils functions which allow to shift the order of each row of a 2d matrix + + Parameters + ---------- + M : (nr, nc) ndarray + Matrix to shift + shifts: int or (nr,) ndarray + + Returns + ------- + Shifted array + + Examples + -------- + >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]]) + >>> roll_cols(M, 2) + array([[2, 3, 1], + [5, 6, 4], + [8, 9, 7]]) + >>> roll_cols(M, np.array([[1],[2],[1]]))) + array([[3, 1, 2], + [5, 6, 4], + [9, 7, 8]]) + + References + ---------- + https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch + + """ + nx = get_backend(M) + + n_rows, n_cols = M.shape + + arange1 = nx.tile(nx.arange(n_cols).reshape((1, n_cols)), (n_rows,1)) + arange2 = (arange1 - shifts) % n_cols + + return nx.take_along_axis(M, arange2, 1) + + +def dCost(theta, u_values, v_values, u_cdf, v_cdf, p=2): + r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + dCp: array-like, shape (n_batch, 1) + The batched right derivative + dCm: array-like, shape (n_batch, 1) + The batched left derivative + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + n = u_values.shape[-1] + m_batch, m = v_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta>=0 + mask_n = v_cdf_theta<0 + + v_values[mask_n] += nx.floor(theta)[mask_n]+1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1,1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1,1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:,0],(-1,1))+1], axis=1) + + ## quantiles of F_u evaluated in F_v^\theta + u_index = nx.searchsorted(u_cdf, v_cdf_theta) + u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n-1), -1) + + ## Deal with 1 + u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:,0], (-1,1))+1], axis=1) + u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:,0], (-1,1))+1], axis=1) + u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") + u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) + + dCp = nx.sum(nx.power(nx.abs(u_icdf_theta-v_values[:,1:]), p) + -nx.power(nx.abs(u_icdf_theta-v_values[:,:-1]), p), axis=-1) + + dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta-v_values[:,1:]), p) + -nx.power(nx.abs(u_icdfm_theta-v_values[:,:-1]), p), axis=-1) + + return dCp.reshape(-1,1), dCm.reshape(-1,1) + + +def Cost(theta, u_values, v_values, u_cdf, v_cdf, p): + r""" Computes the the cost (Equation (6.2) of [1]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + ot_cost: array-like, shape (n_batch,) + OT cost evaluated at theta + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + m_batch, m = v_values.shape + n_batch, n = u_values.shape + + v_cdf_theta = v_cdf -(theta - nx.floor(theta)) + + mask_p = v_cdf_theta>=0 + mask_n = v_cdf_theta<0 + + v_values[mask_n] += nx.floor(theta)[mask_n]+1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + ## Put negative values at the end + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = (-nx.argmin(v_cdf_theta2, axis=-1))# .tolist() + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1,1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1,1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:,0], (-1,1))+1], axis=1) + + ## Compute absciss + cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) + cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0,0),(1,0)]) + + delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] + + ## Compute icdf + u_index = nx.searchsorted(u_cdf, cdf_axis) + u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n-1), -1) + + v_values = nx.concatenate([v_values, nx.reshape(v_values[:,0], (-1,1))+1], axis=1) + v_index = nx.searchsorted(v_cdf_theta, cdf_axis) + v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) + + if p == 1: + ot_cost = nx.sum(delta*nx.abs(u_icdf-v_icdf), axis=-1) + else: + ot_cost = nx.sum(delta*nx.power(nx.abs(u_icdf-v_icdf), p), axis=-1) + + return ot_cost + + +def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): + r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [1]. + + .. math: + OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain + v_values : ndarray, shape (n, ...) + samples in the target domain + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC + Lp : int, optional + Upper bound dC + tm: float, optional + Lower bound theta + tp: float, optional + Upper bound theta + eps: float, optional + Stopping condition + require_sort: bool, optional + If True, sort the values. + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> binary_search_circle(u.T, v.T, p=1) + array([0.1]) + + References + ---------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + + Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0).T + v_cdf = nx.cumsum(v_weights, 0).T + + u_values = u_values.T + v_values = v_values.T + + + L = max(Lm, Lp) + + tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1,1)) + tm = nx.tile(tm, (1, m)) + tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1,1)) + tp = nx.tile(tp, (1, m)) + tc = (tm+tp)/2 + + done = nx.zeros((u_values.shape[0], m)) + + cpt = 0 + while nx.any(1-done): + cpt += 1 + + dCp, dCm = dCost(tc, u_values, v_values, u_cdf, v_cdf, p) + done = ((dCp*dCm)<=0) * 1 + + mask = ((tp-tm)0.001) + tc[mask_end>0] = ((Ctp-Ctm+tm*dCptm-tp*dCmtp)/(dCptm-dCmtp))[mask_end>0] + done[nx.prod(mask, axis=-1)>0] = 1 + elif nx.any(1-done): + tm[((1-mask)*(dCp<0))>0] = tc[((1-mask)*(dCp<0))>0] + tp[((1-mask)*(dCp>=0))>0] = tc[((1-mask)*(dCp>=0))>0] + tc[((1-mask)*(1-done))>0] = (tm[((1-mask)*(1-done))>0]+tp[((1-mask)*(1-done))>0])/2 + + return Cost(tc, u_values, v_values, u_cdf, v_cdf, p) + + +def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): + r"""Computes the 1-Wasserstein distance on the circle using the level median. + + .. math: + W_1(\mu,\nu) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain + v_values : ndarray, shape (n, ...) + samples in the target domain + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + require_sort: bool, optional + If True, sort the values. + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> w1_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + + Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + """ + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + + ## Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) + + cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights),0),values_sorter,0),0) + cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) + + values_sorted = nx.zero_pad(values_sorted, pad_width=[(0,1),(0,0)], value=1) + delta = values_sorted[1:,...]-values_sorted[:-1,...] + weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0) + + sum_weights = nx.cumsum(weight_sorted, axis=0)-0.5 + sum_weights[sum_weights<0] = np.inf + inds = nx.argmin(sum_weights, axis=0) + + levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1,-1)), 0) + + return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) + + +def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): + r"""Computes the Wasserstein distance on the circle using either [1] for p=1 or + the Binary search algorithm proposed in [2] otherwise. + + .. math: + OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain + v_values : ndarray, shape (n, ...) + samples in the target domain + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC. For p>1. + Lp : int, optional + Upper bound dC. For p>1. + tm: float, optional + Lower bound theta. For p>1. + tp: float, optional + Upper bound theta. For p>1. + eps: float, optional + Stopping condition. For p>1. + require_sort: bool, optional + If True, sort the values. + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> w_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if p==1: + return w1_circle(u_values, v_values, u_weights, v_weights, require_sort) + + return binary_search_circle(u_values, v_values, u_weights, v_weights, + p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, + require_sort=require_sort) \ No newline at end of file diff --git a/ot/sliced.py b/ot/sliced.py index 20891a4a0..fa7e833e1 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -107,7 +107,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, p=2, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -208,7 +207,6 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, -------- >>> n_samples_a = 20 - >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 @@ -258,3 +256,94 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, if log: return res, {"projections": projections, "projected_emds": projected_emd} return res + + +def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, + p=2, seed=None, log=False): + r""" + Compute the sliced-Wasserstein distance on the sphere. + + .. math:: + SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right^{\frac{1}{p}} + + where: + - :math: `P^U_\#\mu` stands for the pushforwards of the projection :math: `\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` + + Parameters: + ----------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + X_t: ndarray, shape (n_samples_b, dim) + Samples in the target domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional = + Power p used for computing the spherical sliced Wasserstein + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log : dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + >>> n_samples_a = 20 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> X = X / np.sqrt(np.sum(X**2, -1, keepdims=True)) + >>> sliced_wasserstein_sphere(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0.0 + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2022). Spherical sliced-wasserstein. Interna- +tional Conference on Learning Representations. + """ + from .lp import w_circle + + if a is not None and b is not None: + nx = get_backend(X_s, X_t, a, b) + else: + nx = get_backend(X_s, X_t) + + n, d = X_s.shape + m, _ = X_t.shape + + ## Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2) + + projections, _ = nx.qr(Z) + + ## Projection on S^1 + ## Projection on plane + Xps = nx.reshape(nx.dot(nx.transpose(projections,1,2)[:,None], X_s[:,:,None]),(n_projections, n, 2)) + Xpt = nx.reshape(nx.dot(nx.transpose(projections,1,2)[:,None], X_t[:,:,None]),(n_projections, m, 2)) + + ## Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) + + ## Get coords + Xps = (nx.atan2(-Xps[:,:,1], -Xps[:,:,0])+np.pi)/(2*np.pi) + Xpt = (nx.atan2(-Xpt[:,:,1], -Xpt[:,:,0])+np.pi)/(2*np.pi) + + projected_emd = w_circle(Xps.T, Xpt.T, u_weights=a, v_weights=b, p=p) + res = nx.mean(projected_emd) ** (1/p) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res From 88510d13ea8b69a22348f0fedce101f78b1dba66 Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 7 Feb 2023 18:12:36 +0100 Subject: [PATCH 02/16] Tests + Example SSW_1 --- README.md | 2 +- .../sliced-wasserstein/plot_variance_ssw.py | 156 +++++++++ ot/backend.py | 13 +- ot/lp/solver_1d.py | 310 ++++++++++-------- ot/sliced.py | 72 ++-- test/test_1d_solver.py | 60 ++++ test/test_sliced.py | 140 ++++++++ 7 files changed, 574 insertions(+), 179 deletions(-) create mode 100644 examples/sliced-wasserstein/plot_variance_ssw.py diff --git a/README.md b/README.md index 1d24bc358..7f571e1fa 100644 --- a/README.md +++ b/README.md @@ -299,4 +299,4 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. -[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2022). [Spherical sliced-wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. \ No newline at end of file +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). [Spherical Sliced-Wasserstein](https://openreview.net/forum?id=jXQ0ipgMdU). International Conference on Learning Representations. \ No newline at end of file diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py new file mode 100644 index 000000000..b188f97cd --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +""" +==================================================== +Spherical Sliced Wasserstein on distributions in S^2 +==================================================== + +This example illustrates the computation of the spherical sliced Wasserstein discrepancy as +proposed in [46]. + +[46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). 'Spherical Sliced-Wasserstein". International Conference on Learning Representations. + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +xs = np.random.randn(n, 3) +xt = np.random.randn(n, 3) + +xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True)) +xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True)) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +fig = pl.figure(figsize=(10, 10)) +ax = pl.axes(projection='3d') +ax.grid(False) + +u, v = np.mgrid[0:2 * np.pi:30j, 0:np.pi:30j] +x = np.cos(u) * np.sin(v) +y = np.sin(u) * np.sin(v) +z = np.cos(v) +ax.plot_surface(x, y, z, color="gray", alpha=0.03) +ax.plot_wireframe(x, y, z, linewidth=1, alpha=0.25, color="gray") + +ax.scatter(xs[:, 0], xs[:, 1], xs[:, 2], label="Source") +ax.scatter(xt[:, 0], xt[:, 1], xt[:, 2], label="Target") + +fs = 10 +# Labels +ax.set_xlabel('x', fontsize=fs) +ax.set_ylabel('y', fontsize=fs) +ax.set_zlabel('z', fontsize=fs) + +ax.view_init(20, 120) +ax.set_xlim(-1.5, 1.5) +ax.set_ylim(-1.5, 1.5) +ax.set_zlim(-1.5, 1.5) + +# Ticks +ax.set_xticks([-1, 0, 1]) +ax.set_yticks([-1, 0, 1]) +ax.set_zticks([-1, 0, 1]) + +pl.legend(loc=0) +pl.title("Source and Target distribution") + +############################################################################### +# Spherical Sliced Wasserstein for different seeds and number of projections +# -------------------------------------------------------------------------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, a, b, n_projections, seed=seed, p=1) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +############################################################################### +# Plot Spherical Sliced Wasserstein +# --------------------------------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label=r"$SSW_1$") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() + +############################################################################### +# Spherical Sliced Wasserstein for different seeds and number of samples +# -------------------------------------------------------------------------- + +n_seed = 10 +n_projections = 500 +n_samples_array = np.logspace(1, 4, 10, dtype=int) +res = np.empty((n_seed, 10)) + +# %% Compute statistics +for seed in range(n_seed): + np.random.seed(seed) + for i, n_samples in enumerate(n_samples_array): + + xs = np.random.randn(n_samples, 3) + xt = np.random.randn(n_samples, 3) + + xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True)) + xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True)) + + res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, n_projections=n_projections, seed=seed, p=1) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +############################################################################### +# Plot Spherical Sliced Wasserstein +# --------------------------------- + +pl.figure(2) +pl.plot(n_samples_array, res_mean, label=r"$SSW_1$") +pl.fill_between(n_samples_array, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + +pl.xlabel("Number of samples") +pl.ylabel("Distance") +pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval') + +pl.grid(True) +pl.show() + +# %% diff --git a/ot/backend.py b/ot/backend.py index cf332567d..ac5530a23 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -534,7 +534,7 @@ def concatenate(self, arrays, axis=0): """ raise NotImplementedError() - def zero_pad(self, a, pad_width, value): + def zero_pad(self, a, pad_width, value=0): r""" Pads a tensor. @@ -898,12 +898,11 @@ def is_floating_point(self, a): def tile(self, a, reps): r""" Construct an array by repeating a the number of times given by reps - + See: https://numpy.org/doc/stable/reference/generated/numpy.tile.html """ raise NotImplementedError() - def floor(self, a): r""" Return the floor of the input element-wise @@ -1261,7 +1260,7 @@ def atan2(self, a, b): return np.arctan2(a, b) def transpose(self, a, dim0, dim1): - return np.transpose(a, axes=[0,dim1,dim0]) + return np.transpose(a, axes=[0, dim1, dim0]) class JaxBackend(Backend): @@ -1429,7 +1428,7 @@ def take_along_axis(self, arr, indices, axis): def concatenate(self, arrays, axis=0): return jnp.concatenate(arrays, axis) - def zero_pad(self, a, value=0): + def zero_pad(self, a, pad_width, value=0): return jnp.pad(a, pad_width, constant_values=value) def argmax(self, a, axis=None): @@ -1608,7 +1607,7 @@ def atan2(self, a, b): return jnp.numpy.arctan2(a, b) def transpose(self, a, dim0, dim1): - return jnp.numpy.transpose(a, axes=[0,dim1,dim0]) + return jnp.numpy.transpose(a, axes=[0, dim1, dim0]) class TorchBackend(Backend): @@ -2795,7 +2794,7 @@ def floor(self, a): return tf.floor(a) def prod(self, a, axis): - return tf.experimental.numpy.prod(a, axis=axis) + return tnp.prod(a, axis=axis) def sort2(self, a, axis): return self.sort(a, axis), self.argsort(a, axis) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index cf0d19868..268650aa7 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -367,7 +367,6 @@ def emd2_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, return cost - def roll_cols(M, shifts): r""" Utils functions which allow to shift the order of each row of a 2d matrix @@ -381,7 +380,7 @@ def roll_cols(M, shifts): Returns ------- Shifted array - + Examples -------- >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]]) @@ -389,7 +388,7 @@ def roll_cols(M, shifts): array([[2, 3, 1], [5, 6, 4], [8, 9, 7]]) - >>> roll_cols(M, np.array([[1],[2],[1]]))) + >>> roll_cols(M, np.array([[1],[2],[1]])) array([[3, 1, 2], [5, 6, 4], [9, 7, 8]]) @@ -397,21 +396,20 @@ def roll_cols(M, shifts): References ---------- https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch - """ nx = get_backend(M) n_rows, n_cols = M.shape - arange1 = nx.tile(nx.arange(n_cols).reshape((1, n_cols)), (n_rows,1)) + arange1 = nx.tile(nx.reshape(nx.arange(n_cols), (1, n_cols)), (n_rows, 1)) arange2 = (arange1 - shifts) % n_cols - + return nx.take_along_axis(M, arange2, 1) - + def dCost(theta, u_values, v_values, u_cdf, v_cdf, p=2): r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) - + Parameters ---------- theta: array-like, shape (n_batch, n) @@ -426,66 +424,79 @@ def dCost(theta, u_values, v_values, u_cdf, v_cdf, p=2): cdf of the second empirical distribution p: float, optional = 2 Power p used for computing the Wasserstein distance - + Returns ------- dCp: array-like, shape (n_batch, 1) - The batched right derivative + The batched right derivative dCm: array-like, shape (n_batch, 1) - The batched left derivative - + The batched left derivative + References --------- .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. """ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) - + v_values = nx.copy(v_values) - + n = u_values.shape[-1] m_batch, m = v_values.shape - + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) - - mask_p = v_cdf_theta>=0 - mask_n = v_cdf_theta<0 - - v_values[mask_n] += nx.floor(theta)[mask_n]+1 + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 v_values[mask_p] += nx.floor(theta)[mask_p] - + if nx.any(mask_n) and nx.any(mask_p): v_cdf_theta[mask_n] += 1 - + v_cdf_theta2 = nx.copy(v_cdf_theta) v_cdf_theta2[mask_n] = np.inf shift = (-nx.argmin(v_cdf_theta2, axis=-1)) - v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1,1))) - v_values = roll_cols(v_values, nx.reshape(shift, (-1,1))) - v_values = nx.concatenate([v_values, nx.reshape(v_values[:,0],(-1,1))+1], axis=1) - - ## quantiles of F_u evaluated in F_v^\theta + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + # quantiles of F_u evaluated in F_v^\theta u_index = nx.searchsorted(u_cdf, v_cdf_theta) - u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n-1), -1) - - ## Deal with 1 - u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:,0], (-1,1))+1], axis=1) - u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:,0], (-1,1))+1], axis=1) + u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1) + + # Deal with 1 + u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1) + u_valuesm = nx.concatenate([u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1) + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdfm = u_cdfm.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) - - dCp = nx.sum(nx.power(nx.abs(u_icdf_theta-v_values[:,1:]), p) - -nx.power(nx.abs(u_icdf_theta-v_values[:,:-1]), p), axis=-1) - - dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta-v_values[:,1:]), p) - -nx.power(nx.abs(u_icdfm_theta-v_values[:,:-1]), p), axis=-1) - - return dCp.reshape(-1,1), dCm.reshape(-1,1) + + dCp = nx.sum(nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), axis=-1) + + dCm = nx.sum(nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), axis=-1) + + return dCp.reshape(-1, 1), dCm.reshape(-1, 1) def Cost(theta, u_values, v_values, u_cdf, v_cdf, p): r""" Computes the the cost (Equation (6.2) of [1]) - + Parameters ---------- theta: array-like, shape (n_batch, n) @@ -500,71 +511,82 @@ def Cost(theta, u_values, v_values, u_cdf, v_cdf, p): cdf of the second empirical distribution p: float, optional = 2 Power p used for computing the Wasserstein distance - + Returns ------- ot_cost: array-like, shape (n_batch,) - OT cost evaluated at theta - + OT cost evaluated at theta + References --------- .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. """ nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) - + v_values = nx.copy(v_values) - + m_batch, m = v_values.shape n_batch, n = u_values.shape - v_cdf_theta = v_cdf -(theta - nx.floor(theta)) - - mask_p = v_cdf_theta>=0 - mask_n = v_cdf_theta<0 - - v_values[mask_n] += nx.floor(theta)[mask_n]+1 + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 v_values[mask_p] += nx.floor(theta)[mask_p] - + if nx.any(mask_n) and nx.any(mask_p): v_cdf_theta[mask_n] += 1 - - ## Put negative values at the end + + # Put negative values at the end v_cdf_theta2 = nx.copy(v_cdf_theta) v_cdf_theta2[mask_n] = np.inf - shift = (-nx.argmin(v_cdf_theta2, axis=-1))# .tolist() + shift = (-nx.argmin(v_cdf_theta2, axis=-1)) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) - v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1,1))) - v_values = roll_cols(v_values, nx.reshape(shift, (-1,1))) - v_values = nx.concatenate([v_values, nx.reshape(v_values[:,0], (-1,1))+1], axis=1) - - ## Compute absciss + # Compute absciss cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) - cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0,0),(1,0)]) - + cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)]) + delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] - - ## Compute icdf + + if nx.__name__ == 'torch': + # this is to ensure the best performance for torch searchsorted + # and avoid a warninng related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + cdf_axis = cdf_axis.contiguous() + + # Compute icdf u_index = nx.searchsorted(u_cdf, cdf_axis) - u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n-1), -1) - - v_values = nx.concatenate([v_values, nx.reshape(v_values[:,0], (-1,1))+1], axis=1) + u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1) + + v_values = nx.concatenate([v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1) v_index = nx.searchsorted(v_cdf_theta, cdf_axis) v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) - + if p == 1: - ot_cost = nx.sum(delta*nx.abs(u_icdf-v_icdf), axis=-1) + ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1) else: - ot_cost = nx.sum(delta*nx.power(nx.abs(u_icdf-v_icdf), p), axis=-1) + ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1) return ot_cost -def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, +def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): - r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [1]. + r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. - .. math: - OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + .. math:: + W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + where: + + - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` Parameters ---------- @@ -576,7 +598,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 samples weights in the source domain v_weights : ndarray, shape (n, ...), optional samples weights in the target domain - p : float, optional + p : float, optional (default=1) Power p used for computing the Wasserstein distance Lm : int, optional Lower bound dC @@ -590,7 +612,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 Stopping condition require_sort: bool, optional If True, sort the values. - + Examples -------- >>> u = np.array([[0.2,0.5,0.8]])%1 @@ -601,19 +623,23 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 References ---------- .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. - - Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html """ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) - + if u_weights is not None and v_weights is not None: nx = get_backend(u_values, v_values, u_weights, v_weights) else: nx = get_backend(u_values, v_values) - + n = u_values.shape[0] m = v_values.shape[0] - + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + if u_weights is None: u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: @@ -622,7 +648,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) - + if require_sort: u_sorter = nx.argsort(u_values, 0) u_values = nx.take_along_axis(u_values, u_sorter, 0) @@ -632,57 +658,56 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 u_weights = nx.take_along_axis(u_weights, u_sorter, 0) v_weights = nx.take_along_axis(v_weights, v_sorter, 0) - + u_cdf = nx.cumsum(u_weights, 0).T v_cdf = nx.cumsum(v_weights, 0).T - + u_values = u_values.T v_values = v_values.T - - + L = max(Lm, Lp) - - tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1,1)) + + tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) tm = nx.tile(tm, (1, m)) - tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1,1)) + tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) tp = nx.tile(tp, (1, m)) - tc = (tm+tp)/2 - + tc = (tm + tp) / 2 + done = nx.zeros((u_values.shape[0], m)) - + cpt = 0 - while nx.any(1-done): + while nx.any(1 - done): cpt += 1 - + dCp, dCm = dCost(tc, u_values, v_values, u_cdf, v_cdf, p) - done = ((dCp*dCm)<=0) * 1 - - mask = ((tp-tm)0.001) - tc[mask_end>0] = ((Ctp-Ctm+tm*dCptm-tp*dCmtp)/(dCptm-dCmtp))[mask_end>0] - done[nx.prod(mask, axis=-1)>0] = 1 - elif nx.any(1-done): - tm[((1-mask)*(dCp<0))>0] = tc[((1-mask)*(dCp<0))>0] - tp[((1-mask)*(dCp>=0))>0] = tc[((1-mask)*(dCp>=0))>0] - tc[((1-mask)*(1-done))>0] = (tm[((1-mask)*(1-done))>0]+tp[((1-mask)*(1-done))>0])/2 - + + mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) + tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0] + done[nx.prod(mask, axis=-1) > 0] = 1 + elif nx.any(1 - done): + tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] + tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] + tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2 + return Cost(tc, u_values, v_values, u_cdf, v_cdf, p) def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): - r"""Computes the 1-Wasserstein distance on the circle using the level median. + r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t - .. math: - W_1(\mu,\nu) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t - Parameters ---------- u_values : ndarray, shape (n, ...) @@ -695,7 +720,7 @@ def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=T samples weights in the target domain require_sort: bool, optional If True, sort the values. - + Examples -------- >>> u = np.array([[0.2,0.5,0.8]])%1 @@ -706,18 +731,22 @@ def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=T References ---------- .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. - - Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ """ - + if u_weights is not None and v_weights is not None: nx = get_backend(u_values, v_values, u_weights, v_weights) else: nx = get_backend(u_values, v_values) - + n = u_values.shape[0] m = v_values.shape[0] - + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + if u_weights is None: u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: @@ -726,7 +755,7 @@ def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=T v_weights = nx.full(v_values.shape, 1. / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) - + if require_sort: u_sorter = nx.argsort(u_values, 0) u_values = nx.take_along_axis(u_values, u_sorter, 0) @@ -737,30 +766,29 @@ def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=T u_weights = nx.take_along_axis(u_weights, u_sorter, 0) v_weights = nx.take_along_axis(v_weights, v_sorter, 0) - - ## Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) - - cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights),0),values_sorter,0),0) + + cdf_diff = nx.cumsum(nx.take_along_axis(nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0), 0) cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) - - values_sorted = nx.zero_pad(values_sorted, pad_width=[(0,1),(0,0)], value=1) - delta = values_sorted[1:,...]-values_sorted[:-1,...] + + values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1) + delta = values_sorted[1:, ...] - values_sorted[:-1, ...] weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0) - - sum_weights = nx.cumsum(weight_sorted, axis=0)-0.5 - sum_weights[sum_weights<0] = np.inf + + sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5 + sum_weights[sum_weights < 0] = np.inf inds = nx.argmin(sum_weights, axis=0) - levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1,-1)), 0) + levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0) return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) -def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, +def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): - r"""Computes the Wasserstein distance on the circle using either [1] for p=1 or - the Binary search algorithm proposed in [2] otherwise. + r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or + the binary search algorithm proposed in [44] otherwise. .. math: OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q @@ -775,7 +803,7 @@ def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, samples weights in the source domain v_weights : ndarray, shape (n, ...), optional samples weights in the target domain - p : float, optional + p : float, optional (default=1) Power p used for computing the Wasserstein distance Lm : int, optional Lower bound dC. For p>1. @@ -789,7 +817,7 @@ def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, Stopping condition. For p>1. require_sort: bool, optional If True, sort the values. - + Examples -------- >>> u = np.array([[0.2,0.5,0.8]])%1 @@ -803,10 +831,10 @@ def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. """ assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) - - if p==1: + + if p == 1: return w1_circle(u_values, v_values, u_weights, v_weights, require_sort) - return binary_search_circle(u_values, v_values, u_weights, v_weights, - p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, - require_sort=require_sort) \ No newline at end of file + return binary_search_circle(u_values, v_values, u_weights, v_weights, + p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, + require_sort=require_sort) diff --git a/ot/sliced.py b/ot/sliced.py index fa7e833e1..f51fb9cf8 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -259,15 +259,17 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, - p=2, seed=None, log=False): + p=2, seed=None, log=False): r""" Compute the sliced-Wasserstein distance on the sphere. - + .. math:: - SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right^{\frac{1}{p}} - + SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}} + where: - - :math: `P^U_\#\mu` stands for the pushforwards of the projection :math: `\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` + + - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` + Parameters: ----------- @@ -281,7 +283,7 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, samples weights in the target domain n_projections : int, optional Number of projections used for the Monte-Carlo approximation - p: float, optional = + p: float, optional (default=2) Power p used for computing the spherical sliced Wasserstein seed: int or RandomState or None, optional Seed used for random number generator @@ -294,7 +296,7 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, Spherical Sliced Wasserstein Cost log : dict, optional log dictionary return only if log==True in parameters - + Examples -------- >>> n_samples_a = 20 @@ -305,45 +307,55 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, References ---------- - .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2022). Spherical sliced-wasserstein. Interna- -tional Conference on Learning Representations. + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. """ from .lp import w_circle - + if a is not None and b is not None: nx = get_backend(X_s, X_t, a, b) else: nx = get_backend(X_s, X_t) - + n, d = X_s.shape m, _ = X_t.shape - - ## Uniforms and independent samples on the Stiefel manifold V_{d,2} + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + if nx.any(nx.abs(nx.sum(X_t**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("Xt is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': Z = seed.randn(n_projections, d, 2) else: if seed is not None: nx.seed(seed) - Z = nx.randn(n_projections, d, 2) - + Z = nx.randn(n_projections, d, 2, type_as=X_s) + projections, _ = nx.qr(Z) - - ## Projection on S^1 - ## Projection on plane - Xps = nx.reshape(nx.dot(nx.transpose(projections,1,2)[:,None], X_s[:,:,None]),(n_projections, n, 2)) - Xpt = nx.reshape(nx.dot(nx.transpose(projections,1,2)[:,None], X_t[:,:,None]),(n_projections, m, 2)) - - ## Projection on sphere + + # Projection on S^1 + # Projection on plane + # Xps = nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_s[:, :, None]), (n_projections, n, 2)) ## numpy reshapes in the wrong order + # Xpt = nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_t[:, :, None]), (n_projections, m, 2)) + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_s[:, :, None]), (n_projections, 2, n)), 1, 2) + Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_t[:, :, None]), (n_projections, 2, m)), 1, 2) + + # Projection on sphere Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) - - ## Get coords - Xps = (nx.atan2(-Xps[:,:,1], -Xps[:,:,0])+np.pi)/(2*np.pi) - Xpt = (nx.atan2(-Xpt[:,:,1], -Xpt[:,:,0])+np.pi)/(2*np.pi) - - projected_emd = w_circle(Xps.T, Xpt.T, u_weights=a, v_weights=b, p=p) - res = nx.mean(projected_emd) ** (1/p) - + + # Get coords + Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi) + Xpt_coords = (nx.atan2(-Xpt[:, :, 1], -Xpt[:, :, 0]) + np.pi) / (2 * np.pi) + + projected_emd = w_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) + res = nx.mean(projected_emd) ** (1 / p) + if log: return res, {"projections": projections, "projected_emds": projected_emd} return res diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 20f307a45..36e73a052 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -218,3 +218,63 @@ def test_emd1d_device_tf(): nx.assert_same_dtype_device(xb, emd) nx.assert_same_dtype_device(xb, emd2) assert nx.dtype_device(emd)[1].startswith("GPU") + + +def test_wasserstein_1d_circle(): + # test binary_search_circle and w1_circle give similar results as emd + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + w_u = rng.uniform(0., 1., n) + w_u = w_u / w_u.sum() + + w_v = rng.uniform(0., 1., m) + w_v = w_v / w_v.sum() + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + + wass1 = ot.emd2(w_u, w_v, M1) + + wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) + wass1_circle = ot.w1_circle(u, v, w_u, w_v) + w1_circle = ot.w_circle(u, v, w_u, w_v, p=1) + + M2 = M1**2 + wass2 = ot.emd2(w_u, w_v, M2) + wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) + w2_circle = ot.w_circle(u, v, w_u, w_v, p=2) + + # check loss is similar + np.testing.assert_allclose(wass1, wass1_bsc) + + np.testing.assert_allclose(wass1, wass1_circle, rtol=1e-2) + np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) + + np.testing.assert_allclose(wass2, wass2_bsc) + np.testing.assert_allclose(wass2, w2_circle) + + +@pytest.skip_backend("tf") +def test_wasserstein1d_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) + + w1 = ot.w_circle(xb, xb, rho_ub, rho_vb, p=1) + w2_bsc = ot.w_circle(xb, xb, rho_ub, rho_vb, p=2) + + nx.assert_same_dtype_device(xb, w1) + nx.assert_same_dtype_device(xb, w2_bsc) diff --git a/test/test_sliced.py b/test/test_sliced.py index eb13469ae..464a6c0b4 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -266,3 +266,143 @@ def test_max_sliced_backend_device_tf(): valb = ot.max_sliced_wasserstein_distance(xb, yb, projections=Pb) nx.assert_same_dtype_device(xb, valb) assert nx.dtype_device(valb)[1].startswith("GPU") + + +def test_projections_stiefel(): + rng = np.random.RandomState(0) + + n_projs = 500 + x = np.random.randn(100, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + ssw, log = ot.sliced_wasserstein_sphere(x, x, n_projections=n_projs, + seed=rng, log=True) + + P = log["projections"] + P_T = np.transpose(P, [0, 2, 1]) + np.testing.assert_almost_equal(np.matmul(P_T, P), np.array([np.eye(2) for k in range(n_projs)])) + + +def test_sliced_sphere_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res = ot.sliced_wasserstein_sphere(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_sliced_sphere_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(n, 4) + + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + + +def test_sliced_sphere_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + y = rng.randn(n, 4) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere(x, y, u, u, 10, p=1, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + u = ot.utils.unif(n) + y = rng.randn(n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + res = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) + assert res > 0. + + +def test_1d_sliced_sphere_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + x_coords = (np.arctan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + a = rng.uniform(0, 1, n) + a /= a.sum() + + y = rng.randn(m, 2) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + y_coords = (np.arctan2(-y[:, 1], -y[:, 0]) + np.pi) / (2 * np.pi) + u = ot.utils.unif(m) + + res = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=2) + expected = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=2) + + res1 = ot.sliced_wasserstein_sphere(x, y, a, u, 10, seed=42, p=1) + expected1 = ot.binary_search_circle(x_coords.T, y_coords.T, a, u, p=1) + + np.testing.assert_almost_equal(res ** 2, expected) + np.testing.assert_almost_equal(res1, expected1, decimal=3) + + +@pytest.skip_backend("tf") +def test_sliced_sphere_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + y = rng.randn(2 * n, 3) + y = y / np.sqrt(np.sum(y**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, yb = nx.from_numpy(x, y, type_as=tp) + + valb = ot.sliced_wasserstein_sphere(xb, yb) + + nx.assert_same_dtype_device(xb, valb) From 0f87a5793d5b17352523d7b49c69b563cbc31efa Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 8 Feb 2023 14:46:39 +0100 Subject: [PATCH 03/16] Example Wasserstein Circle + Tests --- examples/plot_compute_w1_circle.py | 127 ++++++++++++++++++ .../sliced-wasserstein/plot_variance_ssw.py | 45 ------- ot/__init__.py | 4 +- ot/backend.py | 51 ++++--- ot/lp/__init__.py | 4 +- ot/lp/solver_1d.py | 23 +++- ot/sliced.py | 4 +- test/test_1d_solver.py | 12 +- test/test_backend.py | 46 +++++++ 9 files changed, 231 insertions(+), 85 deletions(-) create mode 100644 examples/plot_compute_w1_circle.py diff --git a/examples/plot_compute_w1_circle.py b/examples/plot_compute_w1_circle.py new file mode 100644 index 000000000..801411fad --- /dev/null +++ b/examples/plot_compute_w1_circle.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +""" +========================= +OT distance on the Circle +========================= + +Shows how to compute the Wasserstein distance on the circle + + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +import ot + +from scipy.special import iv + +############################################################################## +# Plot data +# --------- + +#%% plot the distributions + + +def pdf_von_Mises(theta, mu, kappa): + pdf = np.exp(kappa * np.cos(theta - mu)) / (2.0 * np.pi * iv(0, kappa)) + return pdf + + +t = np.linspace(0, 2 * np.pi, 1000, endpoint=False) + +mu1 = 1 +kappa1 = 20 + +mu_targets = np.linspace(mu1, mu1 + 2 * np.pi, 10) + + +pdf1 = pdf_von_Mises(t, mu1, kappa1) + + +pl.figure(1) +for k, mu in enumerate(mu_targets): + pdf_t = pdf_von_Mises(t, mu, kappa1) + if k == 0: + label = "Source distributions" + else: + label = None + pl.plot(t / (2 * np.pi), pdf_t, c='b', label=label) + +pl.plot(t / (2 * np.pi), pdf1, c="r", label="Target distribution") +pl.legend() + +mu2 = 0 +kappa2 = kappa1 + +x1 = np.random.vonmises(mu1, kappa1, size=(10,)) + np.pi +x2 = np.random.vonmises(mu2, kappa2, size=(10,)) + np.pi + +angles = np.linspace(0, 2 * np.pi, 150) + +pl.figure(2) +pl.plot(np.cos(angles), np.sin(angles), c="k") +pl.xlim(-1.25, 1.25) +pl.ylim(-1.25, 1.25) +pl.scatter(np.cos(x1), np.sin(x1), c="b") +pl.scatter(np.cos(x2), np.sin(x2), c="r") + +######################################################################################### +# Compare the Euclidean Wasserstein distance with the Wasserstein distance on the circle +# --------------------------------------------------------------------------------------- +# This examples illustrates the periodicity of the Wasserstein distance on the circle. +# We choose as target distribution a von Mises distribution with mean :math:`\mu_{\mathrm{target}}` +# and :math:`\kappa=20`. Then, we compare the distances with samples obtained from a von Mises distribution +# with parameters :math:`\mu_{\mathrm{source}}` and :math:`\kappa=20`. +# The Wasserstein distance on the circle takes into account the periodicity +# and attains its maximum in :math:`\mu_{\mathrm{target}}+1` (the antipodal point) contrary to the +# Euclidean version. + +#%% Compute and plot distributions + +mu_targets = np.linspace(0, 2 * np.pi, 200) +xs = np.random.vonmises(mu1 - np.pi, kappa1, size=(500,)) + np.pi + +n_try = 5 + +xts = np.zeros((n_try, 200, 500)) +for i in range(n_try): + for k, mu in enumerate(mu_targets): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(mu - np.pi, kappa2, size=(500,)) + np.pi + xts[i, k] = xt + +# Put data on S^1=[0,1[ +xts2 = xts / (2 * np.pi) +xs2 = np.concatenate([xs[None] for k in range(200)], axis=0) / (2 * np.pi) + +L_w2_circle = np.zeros((n_try, 200)) +L_w2 = np.zeros((n_try, 200)) + +for i in range(n_try): + w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2) + w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2) + + L_w2_circle[i] = w2_circle + L_w2[i] = w2 + +m_w2_circle = np.mean(L_w2_circle, axis=0) +std_w2_circle = np.std(L_w2_circle, axis=0) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(mu_targets / (2 * np.pi), m_w2_circle, label="Wasserstein circle") +pl.fill_between(mu_targets / (2 * np.pi), m_w2_circle - 2 * std_w2_circle, m_w2_circle + 2 * std_w2_circle, alpha=0.5) +pl.plot(mu_targets / (2 * np.pi), m_w2, label="Euclidean Wasserstein") +pl.fill_between(mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5) +pl.vlines(x=[mu1 / (2 * np.pi)], ymin=0, ymax=np.max(w2), linestyle="--", color="k", label=r"$\mu_{\mathrm{target}}$") +pl.legend() +pl.xlabel(r"$\mu_{\mathrm{source}}$") +pl.show() diff --git a/examples/sliced-wasserstein/plot_variance_ssw.py b/examples/sliced-wasserstein/plot_variance_ssw.py index b188f97cd..83d458f72 100644 --- a/examples/sliced-wasserstein/plot_variance_ssw.py +++ b/examples/sliced-wasserstein/plot_variance_ssw.py @@ -109,48 +109,3 @@ pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval') pl.show() - -############################################################################### -# Spherical Sliced Wasserstein for different seeds and number of samples -# -------------------------------------------------------------------------- - -n_seed = 10 -n_projections = 500 -n_samples_array = np.logspace(1, 4, 10, dtype=int) -res = np.empty((n_seed, 10)) - -# %% Compute statistics -for seed in range(n_seed): - np.random.seed(seed) - for i, n_samples in enumerate(n_samples_array): - - xs = np.random.randn(n_samples, 3) - xt = np.random.randn(n_samples, 3) - - xs = xs / np.sqrt(np.sum(xs**2, -1, keepdims=True)) - xt = xt / np.sqrt(np.sum(xt**2, -1, keepdims=True)) - - res[seed, i] = ot.sliced_wasserstein_sphere(xs, xt, n_projections=n_projections, seed=seed, p=1) - -res_mean = np.mean(res, axis=0) -res_std = np.std(res, axis=0) - -############################################################################### -# Plot Spherical Sliced Wasserstein -# --------------------------------- - -pl.figure(2) -pl.plot(n_samples_array, res_mean, label=r"$SSW_1$") -pl.fill_between(n_samples_array, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) - -pl.legend() -pl.xscale('log') - -pl.xlabel("Number of samples") -pl.ylabel("Distance") -pl.title('Spherical Sliced Wasserstein Distance with 95% confidence inverval') - -pl.grid(True) -pl.show() - -# %% diff --git a/ot/__init__.py b/ot/__init__.py index c109d1425..95e36b9f4 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -38,7 +38,7 @@ from . import gaussian # OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, w1_circle, w_circle +from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, wasserstein1_circle, wasserstein_circle from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) @@ -65,4 +65,4 @@ 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', - 'binary_search_circle', 'w1_circle', 'w_circle'] + 'binary_search_circle', 'wasserstein1_circle', 'wasserstein_circle'] diff --git a/ot/backend.py b/ot/backend.py index ac5530a23..1273f8bfe 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -911,7 +911,7 @@ def floor(self, a): """ raise NotImplementedError() - def prod(self, a, axis): + def prod(self, a, axis=None): r""" Return the product of all elements. @@ -1247,10 +1247,10 @@ def tile(self, a, reps): def floor(self, a): return np.floor(a) - def prod(self, a, axis): + def prod(self, a, axis=0): return np.prod(a, axis=axis) - def sort2(self, a, axis): + def sort2(self, a, axis=-1): return self.sort(a, axis), self.argsort(a, axis) def qr(self, a): @@ -1260,7 +1260,10 @@ def atan2(self, a, b): return np.arctan2(a, b) def transpose(self, a, dim0, dim1): - return np.transpose(a, axes=[0, dim1, dim0]) + dims = list(range(len(a.shape))) + dims[dim0], dims[dim1] = dim1, dim0 + return np.transpose(a, axes=dims) + # return np.transpose(a, axes=[0, dim1, dim0]) class JaxBackend(Backend): @@ -1589,25 +1592,27 @@ def is_floating_point(self, a): return a.dtype.kind == "f" def tile(self, a, reps): - return jnp.numpy.tile(a, reps) + return jnp.tile(a, reps) def floor(self, a): - return jnp.numpy.floor(a) + return jnp.floor(a) - def prod(self, a, axis): - return jnp.numpy.prod(a, axis=axis) + def prod(self, a, axis=0): + return jnp.prod(a, axis=axis) - def sort2(self, a, axis): + def sort2(self, a, axis=-1): return self.sort(a, axis), self.argsort(a, axis) def qr(self, a): - return jnp.numpy.linalg.qr(a) + return jnp.linalg.qr(a) def atan2(self, a, b): - return jnp.numpy.arctan2(a, b) + return jnp.arctan2(a, b) def transpose(self, a, dim0, dim1): - return jnp.numpy.transpose(a, axes=[0, dim1, dim0]) + dims = list(range(len(a.shape))) + dims[dim0], dims[dim1] = dim1, dim0 + return jnp.transpose(a, axes=dims) class TorchBackend(Backend): @@ -2038,10 +2043,10 @@ def tile(self, a, reps): def floor(self, a): return torch.floor(a) - def prod(self, a, axis): + def prod(self, a, axis=0): return torch.prod(a, dim=axis) - def sort2(self, a, axis): + def sort2(self, a, axis=-1): return torch.sort(a, axis) def qr(self, a): @@ -2409,10 +2414,10 @@ def tile(self, a, reps): def floor(self, a): return cp.floor(a) - def prod(self, a, axis): + def prod(self, a, axis=0): return cp.prod(a, axis=axis) - def sort2(self, a, axis): + def sort2(self, a, axis=-1): return self.sort(a, axis), self.argsort(a, axis) def qr(self, a): @@ -2422,7 +2427,9 @@ def atan2(self, a, b): return cp.arctan2(a, b) def transpose(self, a, dim0, dim1): - return cp.transpose(a, axes=[0, dim1, dim0]) + dims = list(range(len(a.shape))) + dims[dim0], dims[dim1] = dim1, dim0 + return cp.transpose(a, axes=dims) class TensorflowBackend(Backend): @@ -2788,15 +2795,15 @@ def is_floating_point(self, a): return a.dtype.is_floating def tile(self, a, reps): - return tf.tile(a, reps) + return tnp.tile(a, reps) def floor(self, a): return tf.floor(a) - def prod(self, a, axis): + def prod(self, a, axis=0): return tnp.prod(a, axis=axis) - def sort2(self, a, axis): + def sort2(self, a, axis=-1): return self.sort(a, axis), self.argsort(a, axis) def qr(self, a): @@ -2806,4 +2813,6 @@ def atan2(self, a, b): return tf.math.atan2(a, b) def transpose(self, a, dim0, dim1): - return tf.transpose(a, perm=[0, dim1, dim0]) \ No newline at end of file + dims = list(range(len(a.shape))) + dims[dim0], dims[dim1] = dim1, dim0 + return tf.transpose(a, perm=dims) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index b3a0f803c..d6fe123d7 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -20,7 +20,7 @@ # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, w1_circle, w_circle +from .solver_1d import emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, wasserstein1_circle, wasserstein_circle from ..utils import dist, list_to_array from ..utils import parmap @@ -28,7 +28,7 @@ __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', - 'binary_search_circle', 'w1_circle', 'w_circle'] + 'binary_search_circle', 'wasserstein1_circle', 'wasserstein_circle'] def check_number_threads(numThreads): diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 268650aa7..8fc50bec3 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -579,7 +579,8 @@ def Cost(theta, u_values, v_values, u_cdf, v_cdf, p): def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): - r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. + r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. + Samples need to be in :math:`S^1\cong [0,1[`. .. math:: W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q @@ -640,6 +641,9 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 if len(v_values.shape) == 1: v_values = nx.reshape(v_values, (m, 1)) + u_values = u_values % 1 + v_values = v_values % 1 + if u_weights is None: u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: @@ -702,8 +706,9 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 return Cost(tc, u_values, v_values, u_cdf, v_cdf, p) -def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): +def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. + Samples need to be in :math:`S^1\cong [0,1[`. .. math:: W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t @@ -725,7 +730,7 @@ def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=T -------- >>> u = np.array([[0.2,0.5,0.8]])%1 >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> w1_circle(u.T, v.T) + >>> wasserstein1_circle(u.T, v.T) array([0.1]) References @@ -747,6 +752,9 @@ def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=T if len(v_values.shape) == 1: v_values = nx.reshape(v_values, (m, 1)) + u_values = u_values % 1 + v_values = v_values % 1 + if u_weights is None: u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: @@ -785,10 +793,11 @@ def w1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=T return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) -def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, - Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): +def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or the binary search algorithm proposed in [44] otherwise. + Samples need to be in :math:`S^1\cong [0,1[`. .. math: OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q @@ -822,7 +831,7 @@ def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, -------- >>> u = np.array([[0.2,0.5,0.8]])%1 >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> w_circle(u.T, v.T) + >>> wasserstein_circle(u.T, v.T) array([0.1]) References @@ -833,7 +842,7 @@ def w_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) if p == 1: - return w1_circle(u_values, v_values, u_weights, v_weights, require_sort) + return wasserstein1_circle(u_values, v_values, u_weights, v_weights, require_sort) return binary_search_circle(u_values, v_values, u_weights, v_weights, p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, diff --git a/ot/sliced.py b/ot/sliced.py index f51fb9cf8..aa21c9315 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -309,7 +309,7 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, ---------- .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. """ - from .lp import w_circle + from .lp import wasserstein_circle if a is not None and b is not None: nx = get_backend(X_s, X_t, a, b) @@ -353,7 +353,7 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi) Xpt_coords = (nx.atan2(-Xpt[:, :, 1], -Xpt[:, :, 0]) + np.pi) / (2 * np.pi) - projected_emd = w_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) + projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) res = nx.mean(projected_emd) ** (1 / p) if log: diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 36e73a052..593727d40 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -221,7 +221,7 @@ def test_emd1d_device_tf(): def test_wasserstein_1d_circle(): - # test binary_search_circle and w1_circle give similar results as emd + # test binary_search_circle and wasserstein_circle give similar results as emd n = 20 m = 30 rng = np.random.RandomState(0) @@ -239,13 +239,13 @@ def test_wasserstein_1d_circle(): wass1 = ot.emd2(w_u, w_v, M1) wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) - wass1_circle = ot.w1_circle(u, v, w_u, w_v) - w1_circle = ot.w_circle(u, v, w_u, w_v, p=1) + wass1_circle = ot.wasserstein1_circle(u, v, w_u, w_v) + w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) M2 = M1**2 wass2 = ot.emd2(w_u, w_v, M2) wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) - w2_circle = ot.w_circle(u, v, w_u, w_v, p=2) + w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2) # check loss is similar np.testing.assert_allclose(wass1, wass1_bsc) @@ -273,8 +273,8 @@ def test_wasserstein1d_circle_devices(nx): xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) - w1 = ot.w_circle(xb, xb, rho_ub, rho_vb, p=1) - w2_bsc = ot.w_circle(xb, xb, rho_ub, rho_vb, p=2) + w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1) + w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2) nx.assert_same_dtype_device(xb, w1) nx.assert_same_dtype_device(xb, w2_bsc) diff --git a/test/test_backend.py b/test/test_backend.py index 3628f6123..aaf985172 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -282,6 +282,20 @@ def test_empty_backend(): nx.array_equal(M, M) with pytest.raises(NotImplementedError): nx.is_floating_point(M) + with pytest.raises(NotImplementedError): + nx.tile(M, (10, 1)) + with pytest.raises(NotImplementedError): + nx.floor(M) + with pytest.raises(NotImplementedError): + nx.prod(M) + with pytest.raises(NotImplementedError): + nx.sort2(M) + with pytest.raises(NotImplementedError): + nx.qr(M) + with pytest.raises(NotImplementedError): + nx.atan2(v, v) + with pytest.raises(NotImplementedError): + nx.transpose(M, 0, 1) def test_func_backends(nx): @@ -603,6 +617,38 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("isfinite") + A = nx.tile(vb, (10, 1)) + lst_b.append(nx.to_numpy(A)) + lst_name.append("tile") + + A = nx.floor(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("floor") + + A = nx.prod(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("prod") + + A, B = nx.sort2(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("sort2 sort") + lst_b.append(nx.to_numpy(B)) + lst_name.append("sort2 argsort") + + A, B = nx.qr(Mb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("QR Q") + lst_b.append(nx.to_numpy(B)) + lst_name.append("QR R") + + A = nx.atan2(vb, vb) + lst_b.append(nx.to_numpy(A)) + lst_name.append("atan2") + + A = nx.transpose(Mb, 0, 1) + lst_b.append(nx.to_numpy(A)) + lst_name.append("transpose") + assert not nx.array_equal(Mb, vb), "array_equal (shape)" assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true" assert not nx.array_equal( From ea73a5d0c5fbdc5e6ce8f5a282dc1d615420bd43 Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 9 Feb 2023 16:16:58 +0100 Subject: [PATCH 04/16] Wasserstein on the circle wrt Unif --- examples/backends/plot_ssw_unif_torch.py | 153 ++++++++++++++++++ ....py => plot_compute_wasserstein_circle.py} | 34 ++++ ot/__init__.py | 10 +- ot/lp/__init__.py | 6 +- ot/lp/solver_1d.py | 80 ++++++++- ot/sliced.py | 86 +++++++++- test/test_1d_solver.py | 45 ++++++ test/test_sliced.py | 46 ++++++ 8 files changed, 451 insertions(+), 9 deletions(-) create mode 100644 examples/backends/plot_ssw_unif_torch.py rename examples/{plot_compute_w1_circle.py => plot_compute_wasserstein_circle.py} (78%) diff --git a/examples/backends/plot_ssw_unif_torch.py b/examples/backends/plot_ssw_unif_torch.py new file mode 100644 index 000000000..d1de5a989 --- /dev/null +++ b/examples/backends/plot_ssw_unif_torch.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +r""" +================================================ +Spherical Sliced-Wasserstein Embedding on Sphere +================================================ + +Here, we aim at transforming samples into a uniform +distribution on the sphere by minimizing SSW: + +.. math:: + \min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i}) + +where :math:`\nu=\mathrm{Unif}(S^1)`. + +""" + +# Author: Clément Bonet +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 3 + +import numpy as np +import matplotlib.pyplot as pl +import matplotlib.animation as animation +import torch +import torch.nn.functional as F + +import ot + + +# %% +# Data generation +# --------------- + +torch.manual_seed(1) + +N = 1000 +x0 = torch.rand(N, 3) +x0 = F.normalize(x0, dim=-1) + + +# %% +# Plot data +# --------- + +def plot_sphere(ax): + xlist = np.linspace(-1.0, 1.0, 50) + ylist = np.linspace(-1.0, 1.0, 50) + r = np.linspace(1.0, 1.0, 50) + X, Y = np.meshgrid(xlist, ylist) + + Z = np.sqrt(r**2 - X**2 - Y**2) + + ax.plot_wireframe(X, Y, Z, color="gray", alpha=.3) + ax.plot_wireframe(X, Y, -Z, color="gray", alpha=.3) # Now plot the bottom half + + +# plot the distributions +pl.figure(1) +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(x0[:, 0], x0[:, 1], x0[:, 2], label='Data samples', alpha=0.5) +ax.set_title('Data distribution') +ax.legend() + + +# %% +# Gradient descent +# ---------------- + +x = x0.clone() +x.requires_grad_(True) + +n_iter = 500 +lr = 100 + +losses = [] +xvisu = torch.zeros(n_iter, N, 3) + +for i in range(n_iter): + sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500) + grad_x = torch.autograd.grad(sw, x)[0] + + x = x - lr * grad_x + x = F.normalize(x, p=2, dim=1) + + losses.append(sw.item()) + xvisu[i, :, :] = x.detach().clone() + + if i % 100 == 0: + print("Iter: {:3d}, loss={}".format(i, losses[-1])) + +pl.figure(1) +pl.semilogy(losses) +pl.grid() +pl.title('SSW') +pl.xlabel("Iterations") + + +# %% +# Plot trajectories of generated samples along iterations +# ------------------------------------------------------- + +ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499] + +fig = pl.figure(3, (10, 10)) +for i in range(9): + # pl.subplot(3, 3, i + 1) + # ax = pl.axes(projection='3d') + ax = fig.add_subplot(3, 3, i + 1, projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[ivisu[i], :, 0], xvisu[ivisu[i], :, 1], xvisu[ivisu[i], :, 2], label='Data samples', alpha=0.5) + ax.set_title('Iter. {}'.format(ivisu[i])) + #ax.axis("off") + if i == 0: + ax.legend() + + +# %% +# Animate trajectories of generated samples along iteration +# ------------------------------------------------------- + +pl.figure(4, (8, 8)) + + +def _update_plot(i): + i = 3 * i + pl.clf() + ax = pl.axes(projection='3d') + plot_sphere(ax) + ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples$', alpha=0.5) + ax.axis("off") + ax.set_xlim((-1.5, 1.5)) + ax.set_ylim((-1.5, 1.5)) + ax.set_title('Iter. {}'.format(i)) + return 1 + + +print(xvisu.shape) + +i = 0 +ax = pl.axes(projection='3d') +plot_sphere(ax) +ax.scatter(xvisu[i, :, 0], xvisu[i, :, 1], xvisu[i, :, 2], label='Data samples from $G\#\mu_n$', alpha=0.5) +ax.axis("off") +ax.set_xlim((-1.5, 1.5)) +ax.set_ylim((-1.5, 1.5)) +ax.set_title('Iter. {}'.format(ivisu[i])) + + +ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000) +# %% diff --git a/examples/plot_compute_w1_circle.py b/examples/plot_compute_wasserstein_circle.py similarity index 78% rename from examples/plot_compute_w1_circle.py rename to examples/plot_compute_wasserstein_circle.py index 801411fad..a4f8b7c41 100644 --- a/examples/plot_compute_w1_circle.py +++ b/examples/plot_compute_wasserstein_circle.py @@ -125,3 +125,37 @@ def pdf_von_Mises(theta, mu, kappa): pl.legend() pl.xlabel(r"$\mu_{\mathrm{source}}$") pl.show() + + +######################################################################## +# Wasserstein distance between von Mises and uniform for different kappa +# ---------------------------------------------------------------------- +# When :math:`\kappa=0`, the von Mises distribution is the uniform distribution on :math:`S^1`. + +#%% Compute Wasserstein between Von Mises and uniform + +kappas = np.logspace(-5, 2, 100) +n_try = 20 + +xts = np.zeros((n_try, 100, 500)) +for i in range(n_try): + for k, kappa in enumerate(kappas): + # np.random.vonmises deals with data on [-pi, pi[ + xt = np.random.vonmises(0, kappa, size=(500,)) + np.pi + xts[i, k] = xt / (2 * np.pi) + +L_w2 = np.zeros((n_try, 100)) +for i in range(n_try): + L_w2[i] = ot.wasserstein2_unif_circle(xts[i].T) + +m_w2 = np.mean(L_w2, axis=0) +std_w2 = np.std(L_w2, axis=0) + +pl.figure(1) +pl.plot(kappas, m_w2) +pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5) +pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$") +pl.xlabel(r"$\kappa$") +pl.show() + +# %% diff --git a/ot/__init__.py b/ot/__init__.py index 95e36b9f4..823c08e9b 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -38,12 +38,15 @@ from . import gaussian # OT functions -from .lp import emd, emd2, emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, wasserstein1_circle, wasserstein_circle +from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein1_circle, + wasserstein_circle, wasserstein2_unif_circle) from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) from .da import sinkhorn_lpl1_mm -from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance, sliced_wasserstein_sphere +from .sliced import (sliced_wasserstein_distance, max_sliced_wasserstein_distance, + sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif) from .gromov import (gromov_wasserstein, gromov_wasserstein2, gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport @@ -65,4 +68,5 @@ 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', - 'binary_search_circle', 'wasserstein1_circle', 'wasserstein_circle'] + 'binary_search_circle', 'wasserstein1_circle', 'wasserstein_circle', + 'wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index d6fe123d7..3618ee5c0 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -20,7 +20,9 @@ # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted -from .solver_1d import emd_1d, emd2_1d, wasserstein_1d, binary_search_circle, wasserstein1_circle, wasserstein_circle +from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, + binary_search_circle, wasserstein1_circle, + wasserstein_circle, wasserstein2_unif_circle) from ..utils import dist, list_to_array from ..utils import parmap @@ -28,7 +30,7 @@ __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', - 'binary_search_circle', 'wasserstein1_circle', 'wasserstein_circle'] + 'binary_search_circle', 'wasserstein1_circle', 'wasserstein_circle', 'wasserstein2_unif_circle'] def check_number_threads(numThreads): diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 8fc50bec3..2df37bb9c 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -614,6 +614,11 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 require_sort: bool, optional If True, sort the values. + Returns + ------- + loss: float + Cost associated to the optimal transportation + Examples -------- >>> u = np.array([[0.2,0.5,0.8]])%1 @@ -726,6 +731,11 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ require_sort: bool, optional If True, sort the values. + Returns + ------- + loss: float + Cost associated to the optimal transportation + Examples -------- >>> u = np.array([[0.2,0.5,0.8]])%1 @@ -799,7 +809,7 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, the binary search algorithm proposed in [44] otherwise. Samples need to be in :math:`S^1\cong [0,1[`. - .. math: + .. math:: OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q Parameters @@ -827,6 +837,11 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort: bool, optional If True, sort the values. + Returns + ------- + loss: float + Cost associated to the optimal transportation + Examples -------- >>> u = np.array([[0.2,0.5,0.8]])%1 @@ -847,3 +862,66 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, return binary_search_circle(u_values, v_values, u_weights, v_weights, p=p, Lm=Lm, Lp=Lp, tm=tm, tp=tp, eps=eps, require_sort=require_sort) + + +def wasserstein2_unif_circle(u_values, u_weights=None): + r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` + + .. math:: + W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12} + + where: + + - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` + + Parameters + ---------- + u_values: ndarray, shape (n, ...) + Samples + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + + Returns + ------- + loss: float + Cost associated to the optimal transportation + + Examples + -------- + >>> x0 = np.array([[0], [0.2], [0.4]]) + >>> wasserstein2_unif_circle(x0) + array([0.02111111]) + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1. / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + u_values = nx.sort(u_values, 0) + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + cpt1 = nx.sum(u_weights * u_values**2, axis=0) + u_mean = nx.sum(u_weights * u_values, axis=0) + + ns = 1 - u_weights - 2 * u_cdf[:-1] + cpt2 = nx.sum(u_values * u_weights * ns, axis=0) + + return cpt1 - u_mean**2 + cpt2 + 1 / 12 diff --git a/ot/sliced.py b/ot/sliced.py index aa21c9315..61b0e02c2 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -270,9 +270,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` - - Parameters: - ----------- + Parameters + ---------- X_s: ndarray, shape (n_samples_a, dim) Samples in the source domain X_t: ndarray, shape (n_samples_b, dim) @@ -359,3 +358,84 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, if log: return res, {"projections": projections, "projected_emds": projected_emd} return res + + +def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log=False): + r"""Compute the 2-spherical sliced wasserstein w.r.t. a uniform distribution. + + .. math:: + SSW_2(\mu_n, \nu) + + where + + - :math:`\mu_n=\sum_{i=1}^n \alpha_i \delta_{x_i}` + - :math:`\nu=\mathrm{Unif}(S^1)` + + Parameters + ---------- + X_s: ndarray, shape (n_samples_a, dim) + Samples in the source domain + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None, optional + Seed used for random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + cost: float + Spherical Sliced Wasserstein Cost + log: dict, optional + log dictionary return only if log==True in parameters + + Examples + --------- + >>> np.random.seed(42) + >>> x0 = np.random.randn(500,3) + >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True)) + >>> sliced_wasserstein_sphere_unif(x0, seed=42) + 0.017335828268020156 + + References: + ----------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. + """ + from .lp import wasserstein2_unif_circle + + if a is not None: + nx = get_backend(X_s, a) + else: + nx = get_backend(X_s) + + n, d = X_s.shape + + if nx.any(nx.abs(nx.sum(X_s**2, axis=-1) - 1) > 10**(-4)): + raise ValueError("X_s is not on the sphere.") + + # Uniforms and independent samples on the Stiefel manifold V_{d,2} + if isinstance(seed, np.random.RandomState) and str(nx) == 'numpy': + Z = seed.randn(n_projections, d, 2) + else: + if seed is not None: + nx.seed(seed) + Z = nx.randn(n_projections, d, 2, type_as=X_s) + + projections, _ = nx.qr(Z) + + # Projection on S^1 + # Projection on plane + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_s[:, :, None]), (n_projections, 2, n)), 1, 2) + # Projection on sphere + Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) + # Get coords + Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi) + + projected_emd = wasserstein2_unif_circle(Xps_coords.T, u_weights=a) + res = nx.mean(projected_emd) ** (1 / 2) + + if log: + return res, {"projections": projections, "projected_emds": projected_emd} + return res diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 593727d40..a55b04831 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -278,3 +278,48 @@ def test_wasserstein1d_circle_devices(nx): nx.assert_same_dtype_device(xb, w1) nx.assert_same_dtype_device(xb, w2_bsc) + + +def test_wasserstein_1d_unif_circle(): + # test wasserstein unif_circle give similar results as wasserstein1d + n = 20 + m = 50000 + + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + # w_u = rng.uniform(0., 1., n) + # w_u = w_u / w_u.sum() + + w_u = ot.utils.unif(n) + + w_v = ot.utils.unif(m) + + M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) + wass2 = ot.emd2(w_u, w_v, M1**2) + + wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) + wass2_unif_circle = ot.wasserstein2_unif_circle(u, w_u) + + # check loss is similar + np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) + np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-3) + + +def test_wasserstein1d_unif_circle_devices(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 1, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) + + w2 = ot.wasserstein2_unif_circle(xb, rho_ub) + + nx.assert_same_dtype_device(xb, w2) diff --git a/test/test_sliced.py b/test/test_sliced.py index 464a6c0b4..ae878901a 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -406,3 +406,49 @@ def test_sliced_sphere_backend_type_devices(nx): valb = ot.sliced_wasserstein_sphere(xb, yb) nx.assert_same_dtype_device(xb, valb) + + +def test_sliced_sphere_unif(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng) + + +def test_sliced_sphere_unif_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_sphere_unif(x, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert projections.shape[0] == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 + + +def test_sliced_sphere_unif_backend_type_devices(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 3) + x = x / np.sqrt(np.sum(x**2, -1, keepdims=True)) + + for tp in nx.__type_list__: + print(nx.dtype_device(tp)) + + xb = nx.from_numpy(x, type_as=tp) + + valb = ot.sliced_wasserstein_sphere_unif(xb) + + nx.assert_same_dtype_device(xb, valb) From 8d5b355dbd7d9e3bccb10ce9e38cf5cac61348ac Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 9 Feb 2023 17:17:22 +0100 Subject: [PATCH 05/16] Example SSW unif --- RELEASES.md | 5 +++++ ot/sliced.py | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 4ed362556..3e5927e79 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,6 +4,11 @@ #### New features +- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced_wasserstein_sphere` and `ot.sliced_wasserstein_sphere_unif` + examples (PR #434) +- Added the Wasserstein distance on the circle in ``ot.wasserstein_circle`` (PR #434) +- Added the 1-Wasserstein distance on the circle in `ot.wasserstein1_circle` (PR #434) +- Added the Wasserstein distance on the circle (for p>=1) in `ot.binary_search_circle` + examples (PR #434) +- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.wasserstein2_unif_circle` (PR #434) - Added Bures Wasserstein distance in `ot.gaussian` (PR ##428) - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) diff --git a/ot/sliced.py b/ot/sliced.py index 61b0e02c2..d966b0fed 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -396,8 +396,9 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log >>> np.random.seed(42) >>> x0 = np.random.randn(500,3) >>> x0 = x0 / np.sqrt(np.sum(x0**2, -1, keepdims=True)) - >>> sliced_wasserstein_sphere_unif(x0, seed=42) - 0.017335828268020156 + >>> ssw = sliced_wasserstein_sphere_unif(x0, seed=42) + >>> np.allclose(sliced_wasserstein_sphere_unif(x0, seed=42), 0.01734, atol=1e-3) + True References: ----------- From 53538365b42b18b02c30369a6aaf411358266bcb Mon Sep 17 00:00:00 2001 From: Clement Date: Fri, 10 Feb 2023 18:18:33 +0100 Subject: [PATCH 06/16] pep8 --- ot/backend.py | 2 +- ot/lp/solver_1d.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 1273f8bfe..a8dd6b038 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -915,7 +915,7 @@ def prod(self, a, axis=None): r""" Return the product of all elements. - See: https://pytorch.org/docs/stable/generated/torch.prod.html + See: https://pytorch.org/docs/stable/generated/torch.prod.html """ raise NotImplementedError() diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 2df37bb9c..1c737f993 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -579,7 +579,7 @@ def Cost(theta, u_values, v_values, u_cdf, v_cdf, p): def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): - r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. + r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. Samples need to be in :math:`S^1\cong [0,1[`. .. math:: From c5f12da99c40ddff76f6a1d60bc67b0f7d46eb40 Mon Sep 17 00:00:00 2001 From: Clement Date: Mon, 13 Feb 2023 11:32:01 +0100 Subject: [PATCH 07/16] np.linalg.qr for numpy < 1.22 by batch + add python3.11 to tests --- .github/workflows/build_tests.yml | 6 +++--- ot/backend.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index ce725c6a4..cea1d079d 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -22,7 +22,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v1 @@ -93,7 +93,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v1 @@ -120,7 +120,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v1 diff --git a/ot/backend.py b/ot/backend.py index a8dd6b038..747bb1a2e 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1254,6 +1254,20 @@ def sort2(self, a, axis=-1): return self.sort(a, axis), self.argsort(a, axis) def qr(self, a): + np_version = tuple([int(k) for k in np.__version__.split(".")]) + if np_version <= (1, 21, 6): + M, N = a.shape[-2], a.shape[-1] + K = min(M, N) + + if len(a.shape) >= 3: + n = a.shape[0] + else: + n = 1 + qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N)) + + for i in range(a.shape[0]): + qs[i], rs[i] = np.linalg.qr(a[i]) + return qs, rs return np.linalg.qr(a) def atan2(self, a, b): From 85598eb20e3338543d7c9dfbc75fc54e507ff87a Mon Sep 17 00:00:00 2001 From: Clement Date: Mon, 13 Feb 2023 11:38:23 +0100 Subject: [PATCH 08/16] np qr --- ot/backend.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 747bb1a2e..79c290774 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1261,12 +1261,15 @@ def qr(self, a): if len(a.shape) >= 3: n = a.shape[0] + + qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N)) + + for i in range(a.shape[0]): + qs[i], rs[i] = np.linalg.qr(a[i]) + else: - n = 1 - qs, rs = np.zeros((n, M, K)), np.zeros((n, K, N)) + return np.linalg.qr(a) - for i in range(a.shape[0]): - qs[i], rs[i] = np.linalg.qr(a[i]) return qs, rs return np.linalg.qr(a) From fec051168792b95447b8c0b93bd7f967737732a8 Mon Sep 17 00:00:00 2001 From: Clement Date: Mon, 13 Feb 2023 18:20:51 +0100 Subject: [PATCH 09/16] rm test python 3.11 --- .github/workflows/build_tests.yml | 6 +++--- ot/backend.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml index cea1d079d..ce725c6a4 100644 --- a/.github/workflows/build_tests.yml +++ b/.github/workflows/build_tests.yml @@ -22,7 +22,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 @@ -93,7 +93,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 @@ -120,7 +120,7 @@ jobs: strategy: max-parallel: 4 matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - uses: actions/checkout@v1 diff --git a/ot/backend.py b/ot/backend.py index 79c290774..5c1812d51 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1255,7 +1255,7 @@ def sort2(self, a, axis=-1): def qr(self, a): np_version = tuple([int(k) for k in np.__version__.split(".")]) - if np_version <= (1, 21, 6): + if np_version < (1, 22, 0): M, N = a.shape[-2], a.shape[-1] K = min(M, N) From 66fccf0b4d6d9dbe0f6b1151bee88aec9203b84f Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 21 Feb 2023 14:22:33 +0100 Subject: [PATCH 10/16] update names, tests, backend transpose --- CONTRIBUTORS.md | 2 +- README.md | 3 +- RELEASES.md | 9 ++-- ot/__init__.py | 8 ++-- ot/backend.py | 41 ++++++++--------- ot/lp/__init__.py | 6 +-- ot/lp/solver_1d.py | 99 +++++++++++++++++++++++++++++++++--------- ot/sliced.py | 27 +++++------- test/test_1d_solver.py | 36 ++++++++++++--- test/test_backend.py | 4 +- test/test_sliced.py | 4 +- 11 files changed, 155 insertions(+), 84 deletions(-) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index ccd0fd1a1..1437821fa 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -41,7 +41,7 @@ The contributors to this library are: * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) -* [Clément Bonet](https://clbonet.github.io) (Spherical Sliced-Wasserstein) +* [Clément Bonet](https://clbonet.github.io) (Wassertstein on circle, Spherical Sliced-Wasserstein) ## Acknowledgments diff --git a/README.md b/README.md index 7f571e1fa..d5e68549b 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,8 @@ POT provides the following generic OT solvers (links to examples): * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36]. +* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] +* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. @@ -294,7 +296,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. - [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. [The statistics of circular optimal transport.](https://arxiv.org/abs/2103.15426) Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. diff --git a/RELEASES.md b/RELEASES.md index 3e5927e79..f8ef653d9 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -4,11 +4,10 @@ #### New features -- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced_wasserstein_sphere` and `ot.sliced_wasserstein_sphere_unif` + examples (PR #434) -- Added the Wasserstein distance on the circle in ``ot.wasserstein_circle`` (PR #434) -- Added the 1-Wasserstein distance on the circle in `ot.wasserstein1_circle` (PR #434) -- Added the Wasserstein distance on the circle (for p>=1) in `ot.binary_search_circle` + examples (PR #434) -- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.wasserstein2_unif_circle` (PR #434) +- Added the spherical sliced-Wasserstein discrepancy in `ot.sliced.sliced_wasserstein_sphere` and `ot.sliced.sliced_wasserstein_sphere_unif` + examples (PR #434) +- Added the Wasserstein distance on the circle in ``ot.lp.solver_1d.wasserstein_circle`` (PR #434) +- Added the Wasserstein distance on the circle (for p>=1) in `ot.lp.solver_1d.binary_search_circle` + examples (PR #434) +- Added the 2-Wasserstein distance on the circle w.r.t a uniform distribution in `ot.lp.solver_1d.semidiscrete_wasserstein2_unif_circle` (PR #434) - Added Bures Wasserstein distance in `ot.gaussian` (PR ##428) - Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376) - Added Free Support Sinkhorn Barycenter + example (PR #387) diff --git a/ot/__init__.py b/ot/__init__.py index 823c08e9b..45d5cfa44 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -39,8 +39,8 @@ # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, - binary_search_circle, wasserstein1_circle, - wasserstein_circle, wasserstein2_unif_circle) + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import (sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2) @@ -68,5 +68,5 @@ 'max_sliced_wasserstein_distance', 'weak_optimal_transport', 'factored_optimal_transport', 'solve', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', - 'binary_search_circle', 'wasserstein1_circle', 'wasserstein_circle', - 'wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] + 'binary_search_circle', 'wasserstein_circle', + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/backend.py b/ot/backend.py index 5c1812d51..9c4246300 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -536,7 +536,7 @@ def concatenate(self, arrays, axis=0): def zero_pad(self, a, pad_width, value=0): r""" - Pads a tensor. + Pads a tensor with a given value (0 by default). This function follows the api from :any:`numpy.pad` @@ -915,7 +915,7 @@ def prod(self, a, axis=None): r""" Return the product of all elements. - See: https://pytorch.org/docs/stable/generated/torch.prod.html + See: https://numpy.org/doc/stable/reference/generated/numpy.prod.html """ raise NotImplementedError() @@ -931,7 +931,7 @@ def qr(self, a): r""" Return the QR factorization - See: https://pytorch.org/docs/stable/generated/torch.linalg.qr.html + See: https://numpy.org/doc/stable/reference/generated/numpy.linalg.qr.html """ raise NotImplementedError() @@ -943,11 +943,11 @@ def atan2(self, a, b): """ raise NotImplementedError() - def transpose(self, a, dim0, dim1): + def transpose(self, a, axes=None): r""" Returns a tensor that is a transposed version of a. The given dimensions dim0 and dim1 are swapped. - See: https://pytorch.org/docs/stable/generated/torch.transpose.html + See: https://numpy.org/doc/stable/reference/generated/numpy.transpose.html """ raise NotImplementedError() @@ -1276,11 +1276,8 @@ def qr(self, a): def atan2(self, a, b): return np.arctan2(a, b) - def transpose(self, a, dim0, dim1): - dims = list(range(len(a.shape))) - dims[dim0], dims[dim1] = dim1, dim0 - return np.transpose(a, axes=dims) - # return np.transpose(a, axes=[0, dim1, dim0]) + def transpose(self, a, axes=None): + return np.transpose(a, axes) class JaxBackend(Backend): @@ -1626,10 +1623,8 @@ def qr(self, a): def atan2(self, a, b): return jnp.arctan2(a, b) - def transpose(self, a, dim0, dim1): - dims = list(range(len(a.shape))) - dims[dim0], dims[dim1] = dim1, dim0 - return jnp.transpose(a, axes=dims) + def transpose(self, a, axes=None): + return jnp.transpose(a, axes) class TorchBackend(Backend): @@ -2072,8 +2067,10 @@ def qr(self, a): def atan2(self, a, b): return torch.atan2(a, b) - def transpose(self, a, dim0, dim1): - return torch.transpose(a, dim0, dim1) + def transpose(self, a, axes=None): + if axes is None: + axes = tuple(range(a.ndim)[::-1]) + return torch.permute(a, axes) class CupyBackend(Backend): # pragma: no cover @@ -2443,10 +2440,8 @@ def qr(self, a): def atan2(self, a, b): return cp.arctan2(a, b) - def transpose(self, a, dim0, dim1): - dims = list(range(len(a.shape))) - dims[dim0], dims[dim1] = dim1, dim0 - return cp.transpose(a, axes=dims) + def transpose(self, a, axes=None): + return cp.transpose(a, axes) class TensorflowBackend(Backend): @@ -2829,7 +2824,5 @@ def qr(self, a): def atan2(self, a, b): return tf.math.atan2(a, b) - def transpose(self, a, dim0, dim1): - dims = list(range(len(a.shape))) - dims[dim0], dims[dim1] = dim1, dim0 - return tf.transpose(a, perm=dims) + def transpose(self, a, axes=None): + return tf.transpose(a, perm=axes) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 3618ee5c0..7d0640fba 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -21,8 +21,8 @@ # import compiled emd from .emd_wrap import emd_c, check_result, emd_1d_sorted from .solver_1d import (emd_1d, emd2_1d, wasserstein_1d, - binary_search_circle, wasserstein1_circle, - wasserstein_circle, wasserstein2_unif_circle) + binary_search_circle, wasserstein_circle, + semidiscrete_wasserstein2_unif_circle) from ..utils import dist, list_to_array from ..utils import parmap @@ -30,7 +30,7 @@ __all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted', 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter', - 'binary_search_circle', 'wasserstein1_circle', 'wasserstein_circle', 'wasserstein2_unif_circle'] + 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle'] def check_number_threads(numThreads): diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 1c737f993..66a3e19fc 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -407,7 +407,7 @@ def roll_cols(M, shifts): return nx.take_along_axis(M, arange2, 1) -def dCost(theta, u_values, v_values, u_cdf, v_cdf, p=2): +def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2): r""" Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1]) Parameters @@ -494,7 +494,7 @@ def dCost(theta, u_values, v_values, u_cdf, v_cdf, p=2): return dCp.reshape(-1, 1), dCm.reshape(-1, 1) -def Cost(theta, u_values, v_values, u_cdf, v_cdf, p): +def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): r""" Computes the the cost (Equation (6.2) of [1]) Parameters @@ -578,9 +578,13 @@ def Cost(theta, u_values, v_values, u_cdf, v_cdf, p): def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, - Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): + Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True, + log=False): r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. - Samples need to be in :math:`S^1\cong [0,1[`. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value % 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. .. math:: W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q @@ -589,12 +593,19 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` + For values :math:`x=(x_1,x_2)\in S^1`, first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + The function runs on backend but tensorflow is not supported. + Parameters ---------- u_values : ndarray, shape (n, ...) - samples in the source domain + samples in the source domain (coordinates on [0,1[) v_values : ndarray, shape (n, ...) - samples in the target domain + samples in the target domain (coordinates on [0,1[) u_weights : ndarray, shape (n, ...), optional samples weights in the source domain v_weights : ndarray, shape (n, ...), optional @@ -613,11 +624,15 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 Stopping condition require_sort: bool, optional If True, sort the values. + log: bool, optional + If True, returns also the optimal theta Returns ------- loss: float Cost associated to the optimal transportation + log: dict, optional + log dictionary returned only if log==True in parameters Examples -------- @@ -646,6 +661,11 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 if len(v_values.shape) == 1: v_values = nx.reshape(v_values, (m, 1)) + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of dimensions {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) + u_values = u_values % 1 v_values = v_values % 1 @@ -688,17 +708,17 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 while nx.any(1 - done): cpt += 1 - dCp, dCm = dCost(tc, u_values, v_values, u_cdf, v_cdf, p) + dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) done = ((dCp * dCm) <= 0) * 1 mask = ((tp - tm) < eps / L) * (1 - done) if nx.any(mask): # can probably be improved by computing only relevant values - dCptp, dCmtp = dCost(tp, u_values, v_values, u_cdf, v_cdf, p) - dCptm, dCmtm = dCost(tm, u_values, v_values, u_cdf, v_cdf, p) - Ctm = Cost(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) - Ctp = Cost(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + dCptp, dCmtp = derivative_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p) + dCptm, dCmtm = derivative_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p) + Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) + Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape(-1, 1) mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) tc[mask_end > 0] = ((Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp))[mask_end > 0] @@ -708,12 +728,20 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] tc[((1 - mask) * (1 - done)) > 0] = (tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0]) / 2 - return Cost(tc, u_values, v_values, u_cdf, v_cdf, p) + w = ot_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + + if log: + return w, {"optimal_theta": tc[:, 0]} + return w def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. - Samples need to be in :math:`S^1\cong [0,1[`. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value % 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. + The function runs on backend but tensorflow is not supported. .. math:: W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t @@ -721,9 +749,9 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ Parameters ---------- u_values : ndarray, shape (n, ...) - samples in the source domain + samples in the source domain (coordinates on [0,1[) v_values : ndarray, shape (n, ...) - samples in the target domain + samples in the target domain (coordinates on [0,1[) u_weights : ndarray, shape (n, ...), optional samples weights in the source domain v_weights : ndarray, shape (n, ...), optional @@ -762,6 +790,11 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ if len(v_values.shape) == 1: v_values = nx.reshape(v_values, (m, 1)) + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of dimensions {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) + u_values = u_values % 1 v_values = v_values % 1 @@ -807,17 +840,34 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, Lm=10, Lp=10, tm=-1, tp=1, eps=1e-6, require_sort=True): r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or the binary search algorithm proposed in [44] otherwise. - Samples need to be in :math:`S^1\cong [0,1[`. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value % 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. + + General loss returned: .. math:: OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + For p=1, [45] + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + For values :math:`x=(x_1,x_2)\in S^1`, first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + The function runs on backend but tensorflow is not supported. + Parameters ---------- u_values : ndarray, shape (n, ...) - samples in the source domain + samples in the source domain (coordinates on [0,1[) v_values : ndarray, shape (n, ...) - samples in the target domain + samples in the target domain (coordinates on [0,1[) u_weights : ndarray, shape (n, ...), optional samples weights in the source domain v_weights : ndarray, shape (n, ...), optional @@ -864,8 +914,12 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=require_sort) -def wasserstein2_unif_circle(u_values, u_weights=None): +def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value % 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. .. math:: W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12} @@ -874,6 +928,11 @@ def wasserstein2_unif_circle(u_values, u_weights=None): - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` + For values :math:`x=(x_1,x_2)\in S^1`, first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + Parameters ---------- u_values: ndarray, shape (n, ...) @@ -889,7 +948,7 @@ def wasserstein2_unif_circle(u_values, u_weights=None): Examples -------- >>> x0 = np.array([[0], [0.2], [0.4]]) - >>> wasserstein2_unif_circle(x0) + >>> semidiscrete_wasserstein2_unif_circle(x0) array([0.02111111]) References diff --git a/ot/sliced.py b/ot/sliced.py index d966b0fed..7d8ea7a38 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -13,6 +13,7 @@ import numpy as np from .backend import get_backend, NumpyBackend from .utils import list_to_array +from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle def get_random_projections(d, n_projections, seed=None, backend=None, type_as=None): @@ -261,7 +262,7 @@ def max_sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, p=2, seed=None, log=False): r""" - Compute the sliced-Wasserstein distance on the sphere. + Compute the spherical sliced-Wasserstein discrepancy. .. math:: SSW_p(\mu,\nu) = \left(\int_{\mathbb{V}_{d,2}} W_p^p(P^U_\#\mu, P^U_\#\nu)\ \mathrm{d}\sigma(U)\right)^{\frac{1}{p}} @@ -270,6 +271,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, - :math:`P^U_\# \mu` stands for the pushforwards of the projection :math:`\forall x\in S^{d-1},\ P^U(x) = \frac{U^Tx}{\|U^Tx\|_2}` + The function runs on backend but tensorflow is not supported. + Parameters ---------- X_s: ndarray, shape (n_samples_a, dim) @@ -287,13 +290,13 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, seed: int or RandomState or None, optional Seed used for random number generator log: bool, optional - if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + if True, sliced_wasserstein_sphere returns the projections used and their associated EMD. Returns ------- cost: float Spherical Sliced Wasserstein Cost - log : dict, optional + log: dict, optional log dictionary return only if log==True in parameters Examples @@ -308,8 +311,6 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, ---------- .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. """ - from .lp import wasserstein_circle - if a is not None and b is not None: nx = get_backend(X_s, X_t, a, b) else: @@ -339,16 +340,14 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, # Projection on S^1 # Projection on plane - # Xps = nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_s[:, :, None]), (n_projections, n, 2)) ## numpy reshapes in the wrong order - # Xpt = nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_t[:, :, None]), (n_projections, m, 2)) - Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_s[:, :, None]), (n_projections, 2, n)), 1, 2) - Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_t[:, :, None]), (n_projections, 2, m)), 1, 2) + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) + Xpt = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_t[:, :, None]), (n_projections, 2, m)), (0, 2, 1)) # Projection on sphere Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) - # Get coords + # Get coordinates on [0,1[ Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi) Xpt_coords = (nx.atan2(-Xpt[:, :, 1], -Xpt[:, :, 0]) + np.pi) / (2 * np.pi) @@ -404,8 +403,6 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log ----------- .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. """ - from .lp import wasserstein2_unif_circle - if a is not None: nx = get_backend(X_s, a) else: @@ -428,13 +425,13 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log # Projection on S^1 # Projection on plane - Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, 1, 2)[:, None], X_s[:, :, None]), (n_projections, 2, n)), 1, 2) + Xps = nx.transpose(nx.reshape(nx.dot(nx.transpose(projections, (0, 2, 1))[:, None], X_s[:, :, None]), (n_projections, 2, n)), (0, 2, 1)) # Projection on sphere Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) - # Get coords + # Get coordinates on [0,1[ Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi) - projected_emd = wasserstein2_unif_circle(Xps_coords.T, u_weights=a) + projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a) res = nx.mean(projected_emd) ** (1 / 2) if log: diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index a55b04831..54518860f 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -239,7 +239,6 @@ def test_wasserstein_1d_circle(): wass1 = ot.emd2(w_u, w_v, M1) wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) - wass1_circle = ot.wasserstein1_circle(u, v, w_u, w_v) w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) M2 = M1**2 @@ -249,10 +248,7 @@ def test_wasserstein_1d_circle(): # check loss is similar np.testing.assert_allclose(wass1, wass1_bsc) - - np.testing.assert_allclose(wass1, wass1_circle, rtol=1e-2) np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) - np.testing.assert_allclose(wass2, wass2_bsc) np.testing.assert_allclose(wass2, w2_circle) @@ -293,14 +289,13 @@ def test_wasserstein_1d_unif_circle(): # w_u = w_u / w_u.sum() w_u = ot.utils.unif(n) - w_v = ot.utils.unif(m) M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) wass2 = ot.emd2(w_u, w_v, M1**2) wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) - wass2_unif_circle = ot.wasserstein2_unif_circle(u, w_u) + wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) # check loss is similar np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-3) @@ -320,6 +315,33 @@ def test_wasserstein1d_unif_circle_devices(nx): xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) - w2 = ot.wasserstein2_unif_circle(xb, rho_ub) + w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub) nx.assert_same_dtype_device(xb, w2) + + +def test_binary_search_circle_log(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n,) + v = rng.rand(m,) + + wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) + optimal_thetas = log["optimal_theta"] + + assert optimal_thetas.shape[0] == 1 + + +def test_wasserstein_circle_bad_shape(): + n = 20 + m = 30 + rng = np.random.RandomState(0) + u = rng.rand(n, 2) + v = rng.rand(m, 1) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=2) + + with pytest.raises(ValueError): + _ = ot.wasserstein_circle(u, v, p=1) diff --git a/test/test_backend.py b/test/test_backend.py index aaf985172..fd9a7613a 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -295,7 +295,7 @@ def test_empty_backend(): with pytest.raises(NotImplementedError): nx.atan2(v, v) with pytest.raises(NotImplementedError): - nx.transpose(M, 0, 1) + nx.transpose(M) def test_func_backends(nx): @@ -645,7 +645,7 @@ def test_func_backends(nx): lst_b.append(nx.to_numpy(A)) lst_name.append("atan2") - A = nx.transpose(Mb, 0, 1) + A = nx.transpose(Mb) lst_b.append(nx.to_numpy(A)) lst_name.append("transpose") diff --git a/test/test_sliced.py b/test/test_sliced.py index ae878901a..f54c799b5 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -311,7 +311,7 @@ def test_sliced_sphere_bad_shapes(): _ = ot.sliced_wasserstein_sphere(x, y, u, u, 10, seed=rng) -def test_sliced_sphere(): +def test_sliced_sphere_values_on_the_sphere(): n = 100 rng = np.random.RandomState(0) @@ -408,7 +408,7 @@ def test_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) -def test_sliced_sphere_unif(): +def test_sliced_sphere_unif_values_on_the_sphere(): n = 100 rng = np.random.RandomState(0) From 1234c964398e2ea3f847efc8ed5459c5c2708657 Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 21 Feb 2023 14:29:02 +0100 Subject: [PATCH 11/16] Comment error batchs --- ot/lp/solver_1d.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 66a3e19fc..16ec87fe1 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -663,8 +663,8 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 if u_values.shape[1] != v_values.shape[1]: raise ValueError( - "u and v must have the same number of dimensions {} and {} respectively given".format(u_values.shape[1], - v_values.shape[1])) + "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) u_values = u_values % 1 v_values = v_values % 1 @@ -792,8 +792,8 @@ def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, requ if u_values.shape[1] != v_values.shape[1]: raise ValueError( - "u and v must have the same number of dimensions {} and {} respectively given".format(u_values.shape[1], - v_values.shape[1])) + "u and v must have the same number of batchs {} and {} respectively given".format(u_values.shape[1], + v_values.shape[1])) u_values = u_values % 1 v_values = v_values % 1 From ad7676509905794d1880ef51f23159ac00dbcfd2 Mon Sep 17 00:00:00 2001 From: Clement Date: Tue, 21 Feb 2023 16:23:35 +0100 Subject: [PATCH 12/16] semidiscrete_wasserstein2_unif_circle example --- examples/plot_compute_wasserstein_circle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_compute_wasserstein_circle.py b/examples/plot_compute_wasserstein_circle.py index a4f8b7c41..3ede96f3c 100644 --- a/examples/plot_compute_wasserstein_circle.py +++ b/examples/plot_compute_wasserstein_circle.py @@ -146,7 +146,7 @@ def pdf_von_Mises(theta, mu, kappa): L_w2 = np.zeros((n_try, 100)) for i in range(n_try): - L_w2[i] = ot.wasserstein2_unif_circle(xts[i].T) + L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T) m_w2 = np.mean(L_w2, axis=0) std_w2 = np.std(L_w2, axis=0) From 2412dbe49ffe3c4abc0a21122fbf1e10401cd315 Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 22 Feb 2023 09:02:08 +0100 Subject: [PATCH 13/16] torch permute method instead of torch.permute for previous versions --- ot/backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ot/backend.py b/ot/backend.py index 9c4246300..2b046802c 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -2070,7 +2070,8 @@ def atan2(self, a, b): def transpose(self, a, axes=None): if axes is None: axes = tuple(range(a.ndim)[::-1]) - return torch.permute(a, axes) + return a.permute(axes) + # return torch.permute(a, axes) class CupyBackend(Backend): # pragma: no cover From 0963c66ceb3620f7a214f19a41e63527d1fefbac Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 22 Feb 2023 14:03:58 +0100 Subject: [PATCH 14/16] update comments and doc --- ot/backend.py | 1 - ot/lp/solver_1d.py | 6 +++--- test/test_1d_solver.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ot/backend.py b/ot/backend.py index 2b046802c..077924350 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -2071,7 +2071,6 @@ def transpose(self, a, axes=None): if axes is None: axes = tuple(range(a.ndim)[::-1]) return a.permute(axes) - # return torch.permute(a, axes) class CupyBackend(Backend): # pragma: no cover diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 16ec87fe1..848b026a9 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -582,7 +582,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 log=False): r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value % 1. + takes the value modulo 1. If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates using e.g. the atan2 function. @@ -738,7 +738,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 def wasserstein1_circle(u_values, v_values, u_weights=None, v_weights=None, require_sort=True): r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value % 1. + takes the value modulo 1. If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates using e.g. the atan2 function. The function runs on backend but tensorflow is not supported. @@ -841,7 +841,7 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, r"""Computes the Wasserstein distance on the circle using either [45] for p=1 or the binary search algorithm proposed in [44] otherwise. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value % 1. + takes the value modulo 1. If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates using e.g. the atan2 function. diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index 54518860f..21abd1da5 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -277,7 +277,7 @@ def test_wasserstein1d_circle_devices(nx): def test_wasserstein_1d_unif_circle(): - # test wasserstein unif_circle give similar results as wasserstein1d + # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle n = 20 m = 50000 From bb26cf7fed04fb68dece0fb24b041388a3c974ec Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 22 Feb 2023 14:09:05 +0100 Subject: [PATCH 15/16] doc wasserstein circle model as [0,1[ --- ot/lp/solver_1d.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 848b026a9..024e1090e 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -583,7 +583,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates using e.g. the atan2 function. .. math:: @@ -593,7 +593,7 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` - For values :math:`x=(x_1,x_2)\in S^1`, first get their coordinates with + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with .. math:: u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} @@ -842,7 +842,7 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, the binary search algorithm proposed in [44] otherwise. Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates using e.g. the atan2 function. General loss returned: @@ -855,7 +855,7 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, .. math:: W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t - For values :math:`x=(x_1,x_2)\in S^1`, first get their coordinates with + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with .. math:: u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} @@ -917,8 +917,8 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value % 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates using e.g. the atan2 function. .. math:: @@ -928,7 +928,7 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` - For values :math:`x=(x_1,x_2)\in S^1`, first get their coordinates with + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with .. math:: u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} From 5a3f367e1359a7cd4dafcd8cd220918dc3dc32e9 Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 22 Feb 2023 15:25:34 +0100 Subject: [PATCH 16/16] Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn --- ot/lp/solver_1d.py | 8 +++++++- ot/sliced.py | 8 ++++---- ot/utils.py | 30 ++++++++++++++++++++++++++++++ test/test_utils.py | 10 ++++++++++ 4 files changed, 51 insertions(+), 5 deletions(-) diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index 024e1090e..e7add89eb 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -598,6 +598,8 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1 .. math:: u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + using e.g. ot.utils.get_coordinate_circle(x) + The function runs on backend but tensorflow is not supported. Parameters @@ -860,6 +862,8 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1, .. math:: u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + using e.g. ot.utils.get_coordinate_circle(x) + The function runs on backend but tensorflow is not supported. Parameters @@ -931,7 +935,9 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with .. math:: - u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, + + using e.g. ot.utils.get_coordinate_circle(x) Parameters ---------- diff --git a/ot/sliced.py b/ot/sliced.py index 7d8ea7a38..077ff0b36 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -12,7 +12,7 @@ import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array +from .utils import list_to_array, get_coordinate_circle from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle @@ -348,8 +348,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50, Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True)) # Get coordinates on [0,1[ - Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi) - Xpt_coords = (nx.atan2(-Xpt[:, :, 1], -Xpt[:, :, 0]) + np.pi) / (2 * np.pi) + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) + Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m)) projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p) res = nx.mean(projected_emd) ** (1 / p) @@ -429,7 +429,7 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log # Projection on sphere Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True)) # Get coordinates on [0,1[ - Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi) + Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n)) projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a) res = nx.mean(projected_emd) ** (1 / 2) diff --git a/ot/utils.py b/ot/utils.py index 9093f096e..3423a7e4d 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -375,6 +375,36 @@ def check_random_state(seed): ' instance'.format(seed)) +def get_coordinate_circle(x): + r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in + turn (in [0,1[). + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + Parameters + ---------- + x: ndarray, shape (n, 2) + Samples on the circle with ambient coordinates + + Returns + ------- + x_t: ndarray, shape (n,) + Coordinates on [0,1[ + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi) + >>> x1, y1 = np.cos(u), np.sin(u) + >>> x = np.concatenate([x1, y1]).T + >>> get_coordinate_circle(x) + array([0.2, 0.5, 0.8]) + """ + nx = get_backend(x) + x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi) + return x_t + + class deprecated(object): r"""Decorator to mark a function or class as deprecated. diff --git a/test/test_utils.py b/test/test_utils.py index 666c157c6..31b12efeb 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -330,3 +330,13 @@ def test_OTResult(): for at in lst_attributes: with pytest.raises(NotImplementedError): getattr(res, at) + + +def test_get_coordinate_circle(): + + u = np.random.rand(1, 100) + x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi)) + x = np.concatenate([x1, y1]).T + x_p = ot.utils.get_coordinate_circle(x) + + np.testing.assert_allclose(u[0], x_p)