Skip to content

Free support barycenters #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 12, 2018
Merged

Conversation

vivienseguy
Copy link
Contributor

2-Wasserstein Barycenter algorithm with an example script. Optimization is carried over the locations of the support (not the weights). Only the unregularized case now. Regularized version will come soon.

@rflamary rflamary changed the title Vivien barycenters Free support barycenters Jul 6, 2018
Copy link
Collaborator

@rflamary rflamary left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @vivienseguy for all the work.

I have several comments that need to be addressed before merging (discussed more in detail below).

But most of all we need a test in the test_ot.py file that call your function and check stuff like the size of the output and reasonable solution.


##############################################################################
# Compute free support barycenter
# -------------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

---- needs to have the proper length for good documentation generation.

@@ -26,7 +27,7 @@ def scipy_sparse_to_spmatrix(A):


def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
"""Compute the entropic regularized wasserstein barycenter of distributions A
"""Compute the Wasserstein barycenter of distributions A
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch !

ot/lp/cvx.py Outdated

Parameters
----------
data_positions : list of (k_i,d) np.ndarray
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

names in the documentation different from the code : data_positions vs measures_locations

ot/lp/cvx.py Outdated
Stop threshol on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing log parameter in the function.

would be nice to return the list of the displacement_square_norm along the iteration in a dictionnary if log=True (similar behavior as barycenter function above that retruns a log)

ot/lp/cvx.py Outdated

iter_count += 1

return X
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add log if log=True

ot/lp/cvx.py Outdated
@@ -10,6 +10,7 @@
import numpy as np
import scipy as sp
import scipy.sparse as sps
import ot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you shouldn't import pot inside a module.

something with relative path like

from .__init__ import emd

is far better since it imports the emd function from the __init__.py

ot/lp/cvx.py Outdated
@@ -144,3 +145,83 @@ def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-po
return b, sol
else:
return b


def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also do b=None if the weights are supposed uniform (needs test an initialization in the function)

X_init = np.random.normal(0., 1., (k, d))
b = np.ones((k,)) / k

X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ot.lp.cvx.free_support_barycenter is very long.

you should import the function in ot.lp __init__.py and add it to __all__ like barycenter so that you can do ot.lp.free_support_barycenter

@rflamary rflamary merged commit 7c5c880 into PythonOT:master Jul 12, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants