Description
Describe the bug
ot.bregman.empirical_sinkhorn_divergence has (like most functions in the module) a parameter stopThr that controls the desired marginal error, with default value 1e-9. However, this function calls empirical_sinkhorn2 with stopThr=1e-9 (hardcoded), so the stopThr parameter of empirical_sinkhorn_divergence has no effect.
Code sample
import numpy as np
import ot
n, m = 10, 10
X = np.random.randn(n, 2)
Y = np.random.randn(m, 2)
ot.bregman.empirical_sinkhorn_divergence(X, Y, reg=1, verbose=True, stopThr=1e-2)
Output (the loop keeps running long after Err falls below stopThr=1e-2):
It. |Err
-------------------
0|1.672739e-01|
10|1.323828e-02|
20|1.179009e-03|
30|1.250032e-04|
40|2.445311e-05|
50|6.149783e-06|
60|1.594631e-06|
70|4.145982e-07|
80|1.078188e-07|
90|2.803949e-08|
100|7.292000e-09|
110|1.896371e-09|
Expected behavior
The iterations should stop when Err < stopThr, not when Err < 1e-9.
Additional information
A similar issue seems to occur in sinkhorn_epsilon_scaling, where the subsequent call sinkhorn_stabilized(..., stopThr=1e-9) also hardcodes stopThr=1e-9.
Proposed fix
Unless this behavior is intentional (in which case it should probably be documented), a reasonable fix is to replace stopThr=1e-9 with stopThr=stopThr in the sub-calls (6 instances in empirical_sinkhorn_divergence and 1 in sinkhorn_epsilon_scaling, as far as I can tell).
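The proposed change is the same one-line pattern for each sub-call: pass the caller's stopThr through instead of a literal. Below is a minimal, self-contained NumPy sketch of that pattern, assuming a toy Sinkhorn-Knopp loop. toy_sinkhorn and toy_sinkhorn_divergence are hypothetical names, not POT functions, and the squared Euclidean cost and the plain transport cost <G, M> are simplifications; the only point illustrated is that every sub-solve receives stopThr=stopThr, so a looser threshold stops earlier.

```python
import numpy as np


def toy_sinkhorn(a, b, M, reg, numItermax=1000, stopThr=1e-9):
    # Hypothetical minimal Sinkhorn-Knopp loop; NOT POT's implementation.
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    n_iter = 0
    for n_iter in range(1, numItermax + 1):
        v = b / (K.T @ u)
        u = a / (K @ v)
        G = u[:, None] * K * v[None, :]
        # After the u-update the row sums equal a, so only the column marginal is checked
        err = np.linalg.norm(G.sum(axis=0) - b)
        if err < stopThr:  # the user-supplied threshold decides when to stop
            break
    return G, n_iter


def toy_sinkhorn_divergence(X, Y, reg, numItermax=1000, stopThr=1e-9):
    # Hypothetical wrapper with three sub-solves (X-Y, X-X, Y-Y),
    # all of which receive the caller's stopThr.
    def sqeuclidean(A, B):
        return ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)

    a = np.full(len(X), 1.0 / len(X))
    b = np.full(len(Y), 1.0 / len(Y))
    costs, iters = [], []
    for P, Q, p, q in [(X, Y, a, b), (X, X, a, a), (Y, Y, b, b)]:
        M = sqeuclidean(P, Q)
        # The fix: forward stopThr=stopThr instead of a hardcoded stopThr=1e-9
        G, n_iter = toy_sinkhorn(p, q, M, reg, numItermax=numItermax, stopThr=stopThr)
        costs.append(np.sum(G * M))  # simplified loss <G, M>
        iters.append(n_iter)
    div = costs[0] - 0.5 * (costs[1] + costs[2])
    return div, iters


rng = np.random.default_rng(0)
X = rng.standard_normal((10, 2))
Y = rng.standard_normal((10, 2))
_, iters_loose = toy_sinkhorn_divergence(X, Y, reg=1, stopThr=1e-2)
_, iters_tight = toy_sinkhorn_divergence(X, Y, reg=1, stopThr=1e-9)
print(iters_loose, iters_tight)  # the looser threshold never needs more iterations
```

With the threshold forwarded, stopThr=1e-2 ends each sub-solve as soon as its marginal error drops below 1e-2, which is the behavior expected from empirical_sinkhorn_divergence.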
I can take care of the fix PR if this is a good solution (or leave it to someone used to contributing to POT, who could do it in two minutes).