-
Notifications
You must be signed in to change notification settings - Fork 528
Free support barycenters #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @vivienseguy for all the work.
I have several comments that need to be addressed before merging (discussed more in detail below).
But most of all we need a test in the test_ot.py file that call your function and check stuff like the size of the output and reasonable solution.
|
||
############################################################################## | ||
# Compute free support barycenter | ||
# ------------- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
----
needs to have the proper length for good documentation generation.
@@ -26,7 +27,7 @@ def scipy_sparse_to_spmatrix(A): | |||
|
|||
|
|||
def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'): | |||
"""Compute the entropic regularized wasserstein barycenter of distributions A | |||
"""Compute the Wasserstein barycenter of distributions A |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch !
ot/lp/cvx.py
Outdated
|
||
Parameters | ||
---------- | ||
data_positions : list of (k_i,d) np.ndarray |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
names in the documentation different from the code : data_positions vs measures_locations
ot/lp/cvx.py
Outdated
Stop threshol on error (>0) | ||
verbose : bool, optional | ||
Print information along iterations | ||
log : bool, optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing log parameter in the function.
would be nice to return the list of the displacement_square_norm along the iteration in a dictionnary if log=True (similar behavior as barycenter function above that retruns a log)
ot/lp/cvx.py
Outdated
|
||
iter_count += 1 | ||
|
||
return X |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add log if log=True
ot/lp/cvx.py
Outdated
@@ -10,6 +10,7 @@ | |||
import numpy as np | |||
import scipy as sp | |||
import scipy.sparse as sps | |||
import ot |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you shouldn't import pot inside a module.
something with relative path like
from .__init__ import emd
is far better since it imports the emd function from the __init__.py
ot/lp/cvx.py
Outdated
@@ -144,3 +145,83 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po | |||
return b, sol | |||
else: | |||
return b | |||
|
|||
|
|||
def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also do b=None if the weights are supposed uniform (needs test an initialization in the function)
X_init = np.random.normal(0., 1., (k, d)) | ||
b = np.ones((k,)) / k | ||
|
||
X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ot.lp.cvx.free_support_barycenter
is very long.
you should import the function in ot.lp __init__.py
and add it to __all__
like barycenter so that you can do ot.lp.free_support_barycenter
2-Wasserstein Barycenter algorithm with an example script. Optimization is carried over the locations of the support (not the weights). Only the unregularized case now. Regularized version will come soon.