Skip to content

Commit 5a3f367

Browse files
committed
Added ot.utils.get_coordinate_circle to get coordinates on the circle in turn
1 parent bb26cf7 commit 5a3f367

File tree

4 files changed

+51
-5
lines changed

4 files changed

+51
-5
lines changed

ot/lp/solver_1d.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,8 @@ def binary_search_circle(u_values, v_values, u_weights=None, v_weights=None, p=1
598598
.. math::
599599
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
600600
601+
using e.g. ot.utils.get_coordinate_circle(x)
602+
601603
The function runs on backend but tensorflow is not supported.
602604
603605
Parameters
@@ -860,6 +862,8 @@ def wasserstein_circle(u_values, v_values, u_weights=None, v_weights=None, p=1,
860862
.. math::
861863
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
862864
865+
using e.g. ot.utils.get_coordinate_circle(x)
866+
863867
The function runs on backend but tensorflow is not supported.
864868
865869
Parameters
@@ -931,7 +935,9 @@ def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None):
931935
For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with
932936
933937
.. math::
934-
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
938+
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi},
939+
940+
using e.g. ot.utils.get_coordinate_circle(x)
935941
936942
Parameters
937943
----------

ot/sliced.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import numpy as np
1414
from .backend import get_backend, NumpyBackend
15-
from .utils import list_to_array
15+
from .utils import list_to_array, get_coordinate_circle
1616
from .lp import wasserstein_circle, semidiscrete_wasserstein2_unif_circle
1717

1818

@@ -348,8 +348,8 @@ def sliced_wasserstein_sphere(X_s, X_t, a=None, b=None, n_projections=50,
348348
Xpt = Xpt / nx.sqrt(nx.sum(Xpt**2, -1, keepdims=True))
349349

350350
# Get coordinates on [0,1[
351-
Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi)
352-
Xpt_coords = (nx.atan2(-Xpt[:, :, 1], -Xpt[:, :, 0]) + np.pi) / (2 * np.pi)
351+
Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
352+
Xpt_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xpt, (-1, 2))), (n_projections, m))
353353

354354
projected_emd = wasserstein_circle(Xps_coords.T, Xpt_coords.T, u_weights=a, v_weights=b, p=p)
355355
res = nx.mean(projected_emd) ** (1 / p)
@@ -429,7 +429,7 @@ def sliced_wasserstein_sphere_unif(X_s, a=None, n_projections=50, seed=None, log
429429
# Projection on sphere
430430
Xps = Xps / nx.sqrt(nx.sum(Xps**2, -1, keepdims=True))
431431
# Get coordinates on [0,1[
432-
Xps_coords = (nx.atan2(-Xps[:, :, 1], -Xps[:, :, 0]) + np.pi) / (2 * np.pi)
432+
Xps_coords = nx.reshape(get_coordinate_circle(nx.reshape(Xps, (-1, 2))), (n_projections, n))
433433

434434
projected_emd = semidiscrete_wasserstein2_unif_circle(Xps_coords.T, u_weights=a)
435435
res = nx.mean(projected_emd) ** (1 / 2)

ot/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,36 @@ def check_random_state(seed):
375375
' instance'.format(seed))
376376

377377

378+
def get_coordinate_circle(x):
379+
r"""For :math:`x\in S^1 \subset \mathbb{R}^2`, returns the coordinates in
380+
turn (in [0,1[).
381+
382+
.. math::
383+
u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}
384+
385+
Parameters
386+
----------
387+
x: ndarray, shape (n, 2)
388+
Samples on the circle with ambient coordinates
389+
390+
Returns
391+
-------
392+
x_t: ndarray, shape (n,)
393+
Coordinates on [0,1[
394+
395+
Examples
396+
--------
397+
>>> u = np.array([[0.2,0.5,0.8]]) * (2 * np.pi)
398+
>>> x1, y1 = np.cos(u), np.sin(u)
399+
>>> x = np.concatenate([x1, y1]).T
400+
>>> get_coordinate_circle(x)
401+
array([0.2, 0.5, 0.8])
402+
"""
403+
nx = get_backend(x)
404+
x_t = (nx.atan2(-x[:, 1], -x[:, 0]) + np.pi) / (2 * np.pi)
405+
return x_t
406+
407+
378408
class deprecated(object):
379409
r"""Decorator to mark a function or class as deprecated.
380410

test/test_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,13 @@ def test_OTResult():
330330
for at in lst_attributes:
331331
with pytest.raises(NotImplementedError):
332332
getattr(res, at)
333+
334+
335+
def test_get_coordinate_circle():
336+
337+
u = np.random.rand(1, 100)
338+
x1, y1 = np.cos(u * (2 * np.pi)), np.sin(u * (2 * np.pi))
339+
x = np.concatenate([x1, y1]).T
340+
x_p = ot.utils.get_coordinate_circle(x)
341+
342+
np.testing.assert_allclose(u[0], x_p)

0 commit comments

Comments
 (0)