-
Notifications
You must be signed in to change notification settings - Fork 528
Stochastic ot #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Stochastic ot #62
Conversation
2785f5c
to
436b228
Compare
abd9952
to
b13feb0
Compare
adf716c
to
cd193f7
Compare
5b8ffa6
to
37e3b29
Compare
199f4fb
to
e508779
Compare
Hello @kilianFatras, Thank you for the PR. Please provide a description of what this PR is doing. If it corrects a bug, add a test (commented) that check that this bug will no reappear in the future. If you have computational gain, an example of script with time before and after the PR would be nice also. |
1dfae87
to
b2b5ffc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hello @kilianFatras,
here are a few small comments and things to change before the merge.
ot/lp/__init__.py
Outdated
@@ -21,6 +21,8 @@ | |||
|
|||
__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx'] | |||
|
|||
__all__=['emd', 'emd2', 'barycenter', 'cvx'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete this line, free_support_barycenter is a new function that has to be imported
test/test_stochastic.py
Outdated
@@ -184,8 +185,34 @@ def test_dual_sgd_sinkhorn(): | |||
|
|||
# check constratints | |||
np.testing.assert_allclose( | |||
zero, (G_sgd - G_sinkhorn).sum(1), atol=1e-02) # cf convergence sgd | |||
zero, (G_sgd - G_sinkhorn).sum(1), atol=1e-04) # cf convergence sgd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this test makes no sens since you should sum the absolute value to test if error and not just the difference:
np.testing.assert_allclose(G_sgd.sum(1),G_sinkhorn.sum(1), atol=1e-02)
same for sum(0)
74c7f5d
to
698c1aa
Compare
3757826
to
fd6371c
Compare
59f9163
to
63b34bf
Compare
This PR fixes a bug in the SGD dual solver function. This function is part of the stochastic semi dual and dual solver framework. There is also a speed up of the function by removing one for loop.