[MRG] EMD and Wasserstein 1D #89
ot/lp/emd_wrap.pyx
Outdated
np.ndarray[double, ndim=2, mode="c"] M):
np.ndarray[double, ndim=2, mode="c"] u,
np.ndarray[double, ndim=2, mode="c"] v,
str metric='sqeuclidean'):
r"""
Roro's stuff
nice documentation indeed ;)
:-P
Thank you Romain, this is nice. Is it just me, or is Cython not particularly fast here (2 s for n=20000)? It is probably due to the use of the dist function; you should probably implement it in Cython for the squared and absolute-value cases and use dist only for weird stuff ;) Rémi.
If I change to the following:
    if metric == 'sqeuclidean':
        m_ij = (u[i, 0] - v[j, 0]) ** 2
    else:
        m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
                    metric=metric)[0, 0]
I get the same timings (on the order of 2 s)... The slow part seems to be when we deal with
I will check if using a sparse representation for
EDIT: OK, when I remove the overhead for
ot/lp/emd_wrap.pyx
Outdated
                dtype=np.float64)
    while i < n and j < m:
        m_ij = dist(u[i].reshape((1, 1)), v[j].reshape((1, 1)),
                    metric=metric)[0, 0]
since you have a pure python function call in the loop I doubt that cython brings you any speed gain.
my 2c
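One way to act on this comment is to resolve the metric string once, before the loop, so that the loop body only performs scalar arithmetic. The sketch below is plain Python and only illustrates the dispatch idea; `resolve_scalar_metric` is a hypothetical helper, not part of POT, and in the `.pyx` file the two branches would presumably become typed inline functions instead of lambdas.

```python
def resolve_scalar_metric(metric):
    """Hypothetical helper: map a metric name to a cheap scalar cost.

    Only the two metrics discussed in this thread are handled; anything
    else is rejected here, whereas the PR falls back to a generic distance.
    """
    if metric == "sqeuclidean":
        return lambda x, y: (x - y) ** 2
    if metric == "euclidean":
        # In 1D the Euclidean distance is the absolute difference.
        return lambda x, y: abs(x - y)
    raise ValueError("unsupported metric for the fast path: %r" % (metric,))


# The lookup happens once, outside the loop.
cost = resolve_scalar_metric("sqeuclidean")
u = [0.0, 1.0, 2.0]
v = [0.5, 1.5, 2.5]
total = 0.0
for x, y in zip(u, v):
    total += cost(x, y)  # scalar arithmetic only, no array reshaping
```

The point of the design is that the per-iteration work no longer builds 1x1 arrays or goes through a general-purpose distance routine.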
I've tried something for basic metrics (euclidean and sqeuclidean), not sure how to do otherwise.
Also, new timings (for a larger problem than above) are:
>>> import ot
>>> import numpy as np
>>> from scipy.stats import wasserstein_distance
>>>
>>> n = 20000
>>> m = 30000
>>> u = np.random.randn(n)
>>> v = np.random.randn(m)
>>>
>>> ot.tic(); _ = wasserstein_distance(u, v); _ = ot.toc()
Elapsed time : 0.012831926345825195 s
>>> ot.tic(); _ = ot.emd_1d([], [], u, v, metric='euclidean', dense=False); _ = ot.toc()
Elapsed time : 0.04144096374511719 s
>>> ot.tic(); M = ot.dist(u.reshape((-1, 1)), v.reshape((-1, 1)),
... metric='euclidean'); _ = ot.emd([], [], M); _ = ot.toc()
RESULT MIGHT BE INACURATE
Max number of iteration reached, currently 100000. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher
/Users/tavenard_r/Documents/costel/src/POT/ot/lp/__init__.py:106: UserWarning: numItermax reached before optimality. Try to increase numItermax.
result_code_string = check_result(result_code)
Elapsed time : 312.5033311843872 s
We are a bit slower than
@rtavenar did you run "cython -a" on the pyx file to see if it's white (no yellow slow Python lines)?
You cannot remove the yellow lines for np.zeros or return. For cdist, either you can directly call BLAS functions from scipy or you need
I've had a look there, could not find obvious matches for distances, but maybe it's not the right place :/ Regarding coding the metrics in Cython, this is what I have done for Euclidean distance and squared Euclidean distance up to now. The question is: should I code all of them even if they are unlikely to be used, or only a subset?
Hello, I think those two are OK; just be clear in the documentation that the others are slower and use cdist (such a slow function btw ;) ) Rémi
OK, and I'll also have to be clear that only strings are accepted as metrics for
OK, so now I added proper docstrings. Let me know if something is missing or should be changed.
This is great, thank you @rtavenar for the code and optimization. I will merge it now.
Hi there,
I started coding a specific EMD for the one-dimensional case (i.e. when sorting both arrays is enough).
Doc is missing for the moment (will do that asap), but a basic implementation that covers the non-uniform weight case and tests that check that the results are consistent with EMD are already there.
On my machine, I ran the following timing test:
Romain
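For readers following the thread, here is a minimal pure-Python sketch of the "sort both arrays" idea with non-uniform weights. It is not the Cython code from this PR: `emd_1d_sketch`, its signature, and the sparse `(i, j, mass)` plan format are assumptions made for illustration, and the cost is taken to be `|x - y| ** p` with `p >= 1`, non-empty inputs, and weights that sum to 1.

```python
def emd_1d_sketch(u, v, a=None, b=None, p=1):
    """Illustrative 1D EMD; hypothetical helper, not POT's emd_1d.

    u, v: non-empty sequences of sample positions.
    a, b: weights summing to 1 (uniform if omitted).
    Returns (cost, plan), where plan is a list of (i, j, mass) triples
    expressed in the original (unsorted) indexing.
    """
    n, m = len(u), len(v)
    a = [1.0 / n] * n if a is None else list(a)
    b = [1.0 / m] * m if b is None else list(b)

    # Sorting is what makes the 1D case cheap: for a convex cost such as
    # |x - y| ** p with p >= 1, moving mass in sorted order is optimal.
    iu = sorted(range(n), key=lambda k: u[k])
    iv = sorted(range(m), key=lambda k: v[k])

    i = j = 0
    cost = 0.0
    plan = []
    rem_a, rem_b = a[iu[0]], b[iv[0]]  # mass still to ship / to receive
    while i < n and j < m:
        mass = min(rem_a, rem_b)
        cost += mass * abs(u[iu[i]] - v[iv[j]]) ** p
        plan.append((iu[i], iv[j], mass))
        rem_a -= mass
        rem_b -= mass
        # Advance whichever side has been exhausted (possibly both).
        if rem_a <= 0.0:
            i += 1
            rem_a = a[iu[i]] if i < n else 0.0
        if rem_b <= 0.0:
            j += 1
            rem_b = b[iv[j]] if j < m else 0.0
    return cost, plan


cost_uniform, plan_uniform = emd_1d_sketch([0.0, 2.0], [3.0, 1.0])
cost_weighted, plan_weighted = emd_1d_sketch(
    [0.0, 1.0], [0.0], a=[0.5, 0.5], b=[1.0])
```

After the two sorts, the loop visits at most `n + m - 1` pairs, which is why a sparse plan is a natural output here.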