-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] Projection Robust Wasserstein #267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## master #267 +/- ##
==========================================
+ Coverage 93.26% 93.34% +0.07%
==========================================
Files 17 17
Lines 3506 3547 +41
==========================================
+ Hits 3270 3311 +41
Misses 236 236 |
You have failing tests @mhhuang95 |
ot/dr.py
Outdated
@@ -198,3 +199,118 @@ def proj(X): | |||
return (X - mx.reshape((1, -1))).dot(Popt) | |||
|
|||
return Popt, proj | |||
|
|||
|
|||
def prw(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the function name is not explicit enough. Can we call it: projection_robust_wasserstein? Maybe @rflamary has a better suggestion
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree projection_robust_wasserstein is longer but explicit, we will also need to change the wda and pca functions and add a deprecation on the old names.
ot/dr.py
Outdated
|
||
def prw(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0): | ||
r""" | ||
Projection Robust Wasserstein Distance _[12],[13] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I doubt this is formatted properly to have links to references below
ot/dr.py
Outdated
Samples from measure \mu | ||
Y : ndarray, shape (n, d) | ||
Samples from measure \nu | ||
a : ndarray, shape (n, 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it necessary to have ndim == 2 with shape[1] == 1? I would favor a flat vector for a and b
ot/dr.py
Outdated
k : int | ||
Subspace dimension | ||
stopThr : float, optional | ||
Accuracy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Accuracy is to me misleading. Maybe tolerance ? But on what criteria?
U = np.random.randn(d, k) | ||
U, _ = np.linalg.qr(U) | ||
else: | ||
U = U0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line need to be covered by a test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
problems spotted by looking at https://724-71472695-gh.circle-artifacts.com/0/dev/gen_modules/ot.dr.html#ot.dr.projection_robust_wasserstein
ot/dr.py
Outdated
The function solves the following optimization problem: | ||
|
||
.. math:: | ||
max_{U \in St(d, k)} min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}*||U^T(x_i - y_j)||^2 - reg * H(\pi) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_{U \in St(d, k)} min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j}*||U^T(x_i - y_j)||^2 - reg * H(\pi) | |
\max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg * H(\pi) |
ot/dr.py
Outdated
|
||
- :math:`U` is a linear projection operator in the Stiefel(d, k) manifold | ||
- :math:`H(\pi)` is entropy regularizer | ||
- :math:`x_i`, `y_j` are samples of measures \mu and \nu respectively |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- :math:`x_i`, `y_j` are samples of measures \mu and \nu respectively | |
- :math:`x_i`, :math:`y_j` are samples of measures \mu and \nu respectively |
Thank you again @mhhuang95 for this contribution and welcome to the POT contributors. |
Types of changes
Motivation and context / Related issue
Code for
A Riemannian Block Coordinate Descent Method for Computing the PRW Distance, ICML 2021
Source: https://github.com/mhhuang95/PRW_RBCD
How has this been tested (if it applies)
tested on Fragmented Hypercube problem
Checklist