Skip to content

[MRG] Projection Robust Wasserstein #267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 6, 2021
Merged

Conversation

mhhuang95
Copy link
Contributor

Types of changes

  • Docs change / refactoring / dependency upgrade
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Motivation and context / Related issue

Code for
A Riemannian Block Coordinate Descent Method for Computing the PRW Distance, ICML 2021
Source: https://github.com/mhhuang95/PRW_RBCD

How has this been tested (if it applies)

tested on Fragmented Hypercube problem

Checklist

  • The documentation is up-to-date with the changes I made.
  • I have read the CONTRIBUTING document.
  • All tests passed, and additional code has been covered with new tests.

@codecov
Copy link

codecov bot commented Aug 5, 2021

Codecov Report

Merging #267 (21ce5b8) into master (c105dcb) will increase coverage by 0.07%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master     #267      +/-   ##
==========================================
+ Coverage   93.26%   93.34%   +0.07%     
==========================================
  Files          17       17              
  Lines        3506     3547      +41     
==========================================
+ Hits         3270     3311      +41     
  Misses        236      236              

@agramfort
Copy link
Collaborator

You have failing tests @mhhuang95

ot/dr.py Outdated
@@ -198,3 +199,118 @@ def proj(X):
return (X - mx.reshape((1, -1))).dot(Popt)

return Popt, proj


def prw(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the function name is not explicit enough. Can we call it: projection_robust_wasserstein? Maybe @rflamary has a better suggestion

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree projection_robust_wasserstein is longer but explicit, we will also need to change the wda and pca functions and add a deprecation on the old names.

ot/dr.py Outdated

def prw(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
r"""
Projection Robust Wasserstein Distance _[12],[13]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt this is formatted properly to have links to references below

ot/dr.py Outdated
Samples from measure \mu
Y : ndarray, shape (n, d)
Samples from measure \nu
a : ndarray, shape (n, 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to have ndim == 2 with shape[1] == 1? I would favor a flat vector for a and b

ot/dr.py Outdated
k : int
Subspace dimension
stopThr : float, optional
Accuracy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accuracy is to me misleading. Maybe tolerance ? But on what criteria?

U = np.random.randn(d, k)
U, _ = np.linalg.qr(U)
else:
U = U0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line need to be covered by a test

@rflamary rflamary changed the title ot.dr: PRW code; text.text_dr: PRW test code. [WIP] Projection Robust Wasserstein Aug 9, 2021
Copy link
Collaborator

@agramfort agramfort left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ot/dr.py Outdated
The function solves the following optimization problem:

.. math::
max_{U \in St(d, k)} min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}*||U^T(x_i - y_j)||^2 - reg * H(\pi)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
max_{U \in St(d, k)} min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}*||U^T(x_i - y_j)||^2 - reg * H(\pi)
\max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi)

ot/dr.py Outdated

- :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
- :math:`H(\pi)` is entropy regularizer
- :math:`x_i`, `y_j` are samples of measures \mu and \nu respectively
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- :math:`x_i`, `y_j` are samples of measures \mu and \nu respectively
- :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively

@rflamary rflamary changed the title [WIP] Projection Robust Wasserstein [MRG] Projection Robust Wasserstein Sep 6, 2021
@rflamary
Copy link
Collaborator

rflamary commented Sep 6, 2021

Thank you again @mhhuang95 for this contribution and welcome to the POT contributors.

@rflamary rflamary merged commit 96bf1a4 into PythonOT:master Sep 6, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants