Skip to content

[MRG] Wasserstein distance on the circle and Spherical Sliced-Wasserstein #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 23, 2023

Conversation

clbonet
Copy link
Contributor

@clbonet clbonet commented Feb 9, 2023

Types of changes

New functions in solver_1d.py:

  • binary_search_circle: binary search to compute the p-Wasserstein distance on the circle (implementation of the algorithm of Fast transport optimization for Monge costs on the circle).
  • wasserstein1_circle: 1-Wasserstein distance on the circle using the closed-form with the level median as proposed in The statistics of circular optimal transport.
  • wasserstein_circle: Compute the p-Wasserstein distance on the circle using the closed-form for p=1 and the binary search otherwise.
  • wasserstein2_unif_circle: Closed-form of the 2-Wasserstein distance between samples and a uniform distribution
  • Cost: Compute the p-Wasserstein objective (used in the binary search)
  • dCost: Compute the left and right derivative of the p-Wasserstein objective (used in the binary search)
  • roll_cols: a util function to roll each row

New function in sliced.py:

  • sliced_wasserstein_sphere: Compute the Spherical Sliced-Wasserstein discrepancy
  • sliced_wasserstein_sphere_unif: SSW_2 between samples and a uniform distribution using the closed-form of the Wasserstein distance

Modifications of backend.py:

  • Added some methods in the backend: tile, floor, prod, sort2, qr, atan2, transpose

New function in utils.py:

  • get_coordinate_circle: Get the coordinates in turn (in [0,1[) from the points on the circle in the ambient space

New examples:

  • plot_compute_wasserstein_circle.py: Example of using ot.wasserstein_circle between von Mises distributions, and ot.wasserstein2_unif_circle
  • plot_variance_ssw.py: Example of using ot.sliced_wasserstein_sphere
  • plot_ssw_unif_torch.py: Gradient descent on particles to learn a uniform distribution

New tests in test_1d_solver.py:

  • test_wasserstein_1d_circle: Compare the values when computing the transport with ot.emd2 and the geodesic distance as ground cost
  • test_wasserstein1d_circle_devices: Test devices (except for tensorflow)
  • test_wasserstein_1d_unif_circle: Compare the values of ot.wasserstein2_unif_circle and an approximation with ot.emd2 or ot.wasserstein_circle
  • test_wasserstein1d_unif_circle_devices: Test devices

New tests in test_sliced.py:

  • test_projections_stiefel: Check that the projections are well on the Stiefel manifold
  • test_sliced_sphere_same_dist: Check that SSW(x,x)=0
  • test_sliced_sphere_bad_shapes: Check the error when the shapes are differents
  • test_sliced_sphere: Check the error when the samples are not on the sphere
  • test_sliced_sphere_log: Check that log returns the projections and the Wasserstein distance between the projected samples
  • test_sliced_sphere_different_dists: Check that SSW is not equal to 0 between different samples
  • test_1d_sliced_sphere_equals_emd: Check that we recover the Wasserstein distance on the circle when samples are in S^1
  • test_sliced_sphere_backend_type_devices: Test devices (except for tensorflow)
  • test_sliced_sphere_unif: Check the error when the samples are not on the sphere
  • test_sliced_sphere_unif_log: Check that log returns the projections and the Wasserstein distance between the projected samples
  • test_sliced_sphere_unif_backend_type_devices: Test devices

New tests in test_backend.py:

  • Added the tests for tile, floor, prod, sort2, qr, atan2 and transpose

New test in test_utils.py:

  • test_get_coordinate_circle: Test that we recover the right coordinates

Motivation and context / Related issue

The Wasserstein distance on the circle can be computed more efficiently using its "closed-forms". For p=1, a closed-form was proposed in The statistics of circular optimal transport, and for p>=1, a binary search algorithm was proposed in Fast transport optimization for Monge costs on the circle.

For data lying on the sphere, we can use a particular sliced-Wasserstein discrepancy by using geodesic projections on the circle.

How has this been tested (if it applies)

  • The computation of Wasserstein on the sphere has been compared with ot.emd2 with geodesic ground cost. For the p=1 implementation with the level median, the results seem a bit less precise than with the binary search (the rtol is reduced in the test) but it is quicker.
  • The spherical sliced-Wasserstein was tested with the same type of test as sliced-Wasserstein
  • I did not test with tensorflow

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@codecov
Copy link

codecov bot commented Feb 13, 2023

Codecov Report

Merging #434 (5a3f367) into master (97feeb3) will increase coverage by 0.25%.
The diff coverage is 99.43%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #434      +/-   ##
==========================================
+ Coverage   94.43%   94.69%   +0.25%     
==========================================
  Files          24       24              
  Lines        6254     6593     +339     
==========================================
+ Hits         5906     6243     +337     
- Misses        348      350       +2     

Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @clbonet

This is an impressive PR! I have found a few comments that need to be taken into account (or at leest discussed). Could you look into that please?

@rflamary rflamary changed the title [WIP] Wasserstein distance on the circle and Spherical Sliced-Wasserstein [MRG] Wasserstein distance on the circle and Spherical Sliced-Wasserstein Feb 22, 2023
@rflamary rflamary merged commit 80e3c23 into PythonOT:master Feb 23, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants