
[MRG] EMD and Wasserstein 1D #89


Merged · 14 commits · Jun 27, 2019

Conversation

rtavenar (Contributor):

Hi there,

I started coding a specific EMD for the mono-dimensional case (i.e. when sorting both arrays is enough).
Docs are missing for the moment (I will add them asap), but a basic implementation that covers the non-uniform weight case, and tests that check that the results are consistent with emd, are already there.
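For intuition, the sorted-supports algorithm can be sketched in plain NumPy. This is a hypothetical illustration under my own naming (`emd_1d_sketch`, the exhaustion tolerance), not the PR's Cython code:

```python
import numpy as np

def emd_1d_sketch(u, v, wu=None, wv=None, metric='sqeuclidean'):
    """1D EMD by merging the two sorted supports (illustrative only)."""
    u = np.asarray(u, dtype=float).ravel()
    v = np.asarray(v, dtype=float).ravel()
    n, m = len(u), len(v)
    # default to uniform weights, as ot.emd_1d does when passed []
    wu = np.full(n, 1. / n) if wu is None else np.asarray(wu, dtype=float)
    wv = np.full(m, 1. / m) if wv is None else np.asarray(wv, dtype=float)
    iu, iv = np.argsort(u), np.argsort(v)
    u, wu = u[iu], wu[iu].copy()
    v, wv = v[iv], wv[iv].copy()
    i = j = 0
    cost = 0.
    while i < n and j < m:
        d = u[i] - v[j]
        c = d * d if metric == 'sqeuclidean' else abs(d)
        w = min(wu[i], wv[j])      # mass exchanged between the two atoms
        cost += w * c
        wu[i] -= w
        wv[j] -= w
        if wu[i] <= 1e-16:         # atom exhausted: advance its index
            i += 1
        if wv[j] <= 1e-16:
            j += 1
    return cost
```

Each iteration exhausts at least one atom, so the merge loop runs at most n + m times and the sorts dominate at O(n log n + m log m).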

On my machine, I ran the following timing test:

>>> n = 20000
>>> m = 3000
>>> u = np.random.randn(n, 1)
>>> v = np.random.randn(m, 1)
>>> ot.tic(); ot.emd_1d([], [], u, v, metric='sqeuclidean'); ot.toc()
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
Elapsed time : 2.3728668689727783 s
2.3728668689727783
>>> ot.tic(); M = ot.dist(u, v, metric='sqeuclidean'); ot.emd([], [], M); ot.toc()
RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
/Users/tavenard_r/Documents/costel/src/POT/ot/lp/__init__.py:104: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
Elapsed time : 8.67806887626648 s

Romain

Review comment on the Cython signatures and this placeholder docstring:

    np.ndarray[double, ndim=2, mode="c"] M):
    np.ndarray[double, ndim=2, mode="c"] u,
    np.ndarray[double, ndim=2, mode="c"] v,
    str metric='sqeuclidean'):
    r"""
    Roro's stuff
Collaborator:

nice documentation indeed ;)

rtavenar (author):

:-P

@rflamary (Collaborator):

Thank you Romain, this is nice.

Is it me, or is Cython not particularly fast here (2 s for n = 20000)? It is probably due to the use of the dist function; you should probably implement it in Cython for the squared and absolute-value metrics and use dist only for weird stuff ;)

Rémi.

@rtavenar (author) commented Jun 20, 2019

If I change to the following:

        if metric == 'sqeuclidean':
            m_ij = (u[i, 0] - v[j, 0]) ** 2
        else:
            m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
                        metric=metric)[0, 0]

I get the same timings (on the order of 2 seconds)...

The slow part seems to be the handling of G. If I remove everything related to G, I get:

Elapsed time : 0.0061719417572021484 s

I will check if using a sparse representation for G helps.

EDIT: OK, when I remove the overhead for G, I see a 100x improvement in timings with this if...else, so I will keep it for the L1 and L2 norms and resort to dist for other distances.
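On the sparse idea: the plan produced by the sorted merge has at most n + m - 1 nonzero entries, so COO-style triplets are enough and the dense n × m matrix never needs to be allocated. A hypothetical sketch (the name `emd_1d_plan_coo` is mine, not the PR's code), taking the already-sorted weight vectors:

```python
import numpy as np

def emd_1d_plan_coo(wu, wv):
    """Optimal 1D plan between sorted supports, as (row, col, weight) triplets."""
    wu = [float(w) for w in wu]
    wv = [float(w) for w in wv]
    rows, cols, vals = [], [], []
    i = j = 0
    while i < len(wu) and j < len(wv):
        w = min(wu[i], wv[j])   # mass exchanged between atoms i and j
        rows.append(i)
        cols.append(j)
        vals.append(w)
        wu[i] -= w
        wv[j] -= w
        if wu[i] <= 1e-16:      # atom exhausted: advance
            i += 1
        if wv[j] <= 1e-16:
            j += 1
    return np.array(rows), np.array(cols), np.array(vals)
```

If a matrix object is needed, the triplets can be handed to scipy.sparse.coo_matrix, which is presumably what a dense=False option would return.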

Review comment on the inner loop, which calls dist once per pair:

    dtype=np.float64)
    while i < n and j < m:
        m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
                    metric=metric)[0, 0]
Collaborator:

Since you have a pure Python function call in the loop, I doubt that Cython brings you any speed gain.

my 2c

rtavenar (author):

I've tried something for the basic metrics (euclidean and sqeuclidean); I'm not sure how to do it otherwise.

@rtavenar (author) commented Jun 21, 2019

Also, the new timings (for a larger problem than above) are:

>>> import ot
>>> import numpy as np
>>> from scipy.stats import wasserstein_distance
>>> 
>>> n = 20000
>>> m = 30000
>>> u = np.random.randn(n)
>>> v = np.random.randn(m)
>>> 
>>> ot.tic(); _ = wasserstein_distance(u, v); _ = ot.toc()
Elapsed time : 0.012831926345825195 s
>>> ot.tic(); _ = ot.emd_1d([], [], u, v, metric='euclidean', dense=False); _ = ot.toc()
Elapsed time : 0.04144096374511719 s
>>> ot.tic(); M = ot.dist(u.reshape((-1, 1)), v.reshape((-1, 1)),
...                       metric='euclidean'); _ = ot.emd([], [], M); _ = ot.toc()

RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
/Users/tavenard_r/Documents/costel/src/POT/ot/lp/__init__.py:106: UserWarning: numItermax reached before optimality. Try to increase numItermax.
  result_code_string = check_result(result_code)
Elapsed time : 312.5033311843872 s

We are a bit slower than scipy's implementation; I am not sure whether this is due to Cython or to the fact that scipy does not build the transport plan G :/
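For what it's worth, scipy's wasserstein_distance only computes the scalar W1 by integrating the difference of the empirical CDFs and never materializes a plan, which would explain part of the gap. A rough sketch of that CDF approach (hypothetical code and naming, not scipy's actual implementation; uniform weights only):

```python
import numpy as np

def w1_cdf(u, v):
    """W1 between two empirical distributions via integrating |F_u - F_v|."""
    u = np.sort(np.asarray(u, dtype=float))
    v = np.sort(np.asarray(v, dtype=float))
    all_x = np.sort(np.concatenate([u, v]))
    deltas = np.diff(all_x)                      # widths of the CDF steps
    # empirical CDFs evaluated on each interval between merged points
    cdf_u = np.searchsorted(u, all_x[:-1], side='right') / len(u)
    cdf_v = np.searchsorted(v, all_x[:-1], side='right') / len(v)
    return np.sum(np.abs(cdf_u - cdf_v) * deltas)
```

Note that this only yields the euclidean (W1) cost, whereas emd_1d also supports other metrics and can return the plan itself.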

@agramfort (Collaborator):

@rtavenar did you run "cython -a" on the .pyx file to see whether it's all white (no yellow slow-Python lines)?

@rtavenar (author):

@agramfort

  1. I did not know this one; thanks for the tip.
  2. I've changed the np.abs; now the only yellow lines I get are the return, the np.zeros lines, and the cdist call, but I do not know how to remove those.

@agramfort (Collaborator):

You cannot remove the yellow lines for np.zeros or return.

For cdist, either call BLAS functions from scipy directly, or code the metrics in Cython.

@rtavenar (author):

I've had a look there; I could not find obvious matches for distances, but maybe I wasn't looking in the right place :/

Regarding coding the metrics in Cython, this is what I have done for the Euclidean and squared Euclidean distances so far. The question is: should I code all of them, even those unlikely to be used, or only a subset?

@rflamary (Collaborator):

Hello, I think those two are OK; just be clear in the documentation that the others are slower and use cdist (such a slow function, btw ;))

Rémi

@rtavenar (author):

OK, and I'll also have to make it clear that only strings are accepted as metrics for emd_1d.

@rtavenar (author):

OK, I have now added proper docstrings. Let me know if something is missing or should be changed.

@rflamary (Collaborator):

This is great, thank you @rtavenar for the code and optimization.

I will merge it now.

@rflamary rflamary changed the title [WIP] EMD 1d [MRG] EMD and Wassersyein 1D Jun 27, 2019
@rflamary rflamary changed the title [MRG] EMD and Wassersyein 1D [MRG] EMD and Wasserstein 1D Jun 27, 2019
@rflamary rflamary merged commit a9b8af1 into PythonOT:master Jun 27, 2019