-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] Adding greenkhorn #66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
653fd00
eb17e02
eb7a395
ff824a2
f3433fd
7ffd4fe
24a53ef
55e8392
1d49410
75fe96c
414331c
dee6d6e
8f908bd
1b24b1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,7 +47,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, | |
reg : float | ||
Regularization term >0 | ||
method : str | ||
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or | ||
method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or | ||
'sinkhorn_epsilon_scaling', see those function for specific parameters | ||
numItermax : int, optional | ||
Max number of iterations | ||
|
@@ -103,6 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, | |
def sink(): | ||
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, | ||
stopThr=stopThr, verbose=verbose, log=log, **kwargs) | ||
if method.lower() == 'greenkhorn': | ||
def sink(): | ||
return greenkhorn(a, b, M, reg, numItermax=numItermax, | ||
stopThr=stopThr, verbose=verbose, log=log) | ||
elif method.lower() == 'sinkhorn_stabilized': | ||
def sink(): | ||
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, | ||
|
@@ -197,13 +201,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, | |
|
||
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. | ||
|
||
[21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 | ||
|
||
|
||
|
||
See Also | ||
-------- | ||
ot.lp.emd : Unregularized OT | ||
ot.optim.cg : General regularized OT | ||
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2] | ||
ot.bregman.greenkhorn : Greenkhorn [21] | ||
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10] | ||
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10] | ||
|
||
|
@@ -410,6 +417,148 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, | |
return u.reshape((-1, 1)) * K * v.reshape((1, -1)) | ||
|
||
|
||
|
||
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log = False): | ||
""" | ||
Solve the entropic regularization optimal transport problem and return the OT matrix | ||
|
||
The algorithm used is based on the paper | ||
|
||
Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration | ||
by Jason Altschuler, Jonathan Weed, Philippe Rigollet | ||
appeared at NIPS 2017 | ||
|
||
which is a stochastic version of the Sinkhorn-Knopp algorithm [2]. | ||
|
||
The function solves the following optimization problem: | ||
|
||
.. math:: | ||
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma) | ||
|
||
s.t. \gamma 1 = a | ||
|
||
\gamma^T 1= b | ||
|
||
\gamma\geq 0 | ||
where : | ||
|
||
- M is the (ns,nt) metric cost matrix | ||
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` | ||
- a and b are source and target weights (sum to 1) | ||
|
||
|
||
|
||
Parameters | ||
---------- | ||
a : np.ndarray (ns,) | ||
samples weights in the source domain | ||
b : np.ndarray (nt,) or np.ndarray (nt,nbb) | ||
samples in the target domain, compute sinkhorn with multiple targets | ||
and fixed M if b is a matrix (return OT loss + dual variables in log) | ||
M : np.ndarray (ns,nt) | ||
loss matrix | ||
reg : float | ||
Regularization term >0 | ||
numItermax : int, optional | ||
Max number of iterations | ||
stopThr : float, optional | ||
Stop threshol on error (>0) | ||
log : bool, optional | ||
record log if True | ||
|
||
|
||
Returns | ||
------- | ||
gamma : (ns x nt) ndarray | ||
Optimal transportation matrix for the given parameters | ||
log : dict | ||
log dictionary return only if log==True in parameters | ||
|
||
Examples | ||
-------- | ||
|
||
>>> import ot | ||
>>> a=[.5,.5] | ||
>>> b=[.5,.5] | ||
>>> M=[[0.,1.],[1.,0.]] | ||
>>> ot.sinkhorn(a,b,M,1) | ||
array([[ 0.36552929, 0.13447071], | ||
[ 0.13447071, 0.36552929]]) | ||
|
||
|
||
References | ||
---------- | ||
|
||
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 | ||
[21] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017 | ||
|
||
|
||
See Also | ||
-------- | ||
ot.lp.emd : Unregularized OT | ||
ot.optim.cg : General regularized OT | ||
|
||
""" | ||
|
||
i = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not needed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indeed, thougt I got rid of all of them... |
||
|
||
n = a.shape[0] | ||
m = b.shape[0] | ||
|
||
# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute | ||
K = np.empty(M.shape, dtype=M.dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. K = np.empty_like(M) |
||
np.divide(M, -reg, out=K) | ||
np.exp(K, out=K) | ||
|
||
u = np.ones(n)/n | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. np.full(n, 1. / n) |
||
v = np.ones(m)/m | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. np.full(m, 1. / m) |
||
G = np.diag(u)@K@np.diag(v) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use broadcasting to avoid filling diagonal matrices
|
||
|
||
one_n = np.ones(n) | ||
one_m = np.ones(m) | ||
viol = G@one_m - a | ||
viol_2 = G.T@one_n - b | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here to allocate arrays of ones to compute sum of rows and columns. I would just use np.sum(..., axis=) |
||
stopThr_val = 1 | ||
if log: | ||
log['u'] = u | ||
log['v'] = v | ||
|
||
while i < numItermax and stopThr_val > stopThr: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. rather than using a while you could use a for loop. For optim solvers I tend to do:
so you can easily print a message when you did not converge. |
||
i +=1 | ||
i_1 = np.argmax(np.abs(viol)) | ||
i_2 = np.argmax(np.abs(viol_2)) | ||
m_viol_1 = np.abs(viol[i_1]) | ||
m_viol_2 = np.abs(viol_2[i_2]) | ||
stopThr_val = np.maximum(m_viol_1,m_viol_2) | ||
|
||
if m_viol_1 > m_viol_2: | ||
old_u = u[i_1] | ||
u[i_1] = a[i_1]/(K[i_1,:]@v) | ||
G[i_1,:] = u[i_1]*K[i_1,:]*v | ||
|
||
viol[i_1] = u[i_1]*K[i_1,:]@v - a[i_1] | ||
viol_2 = viol_2 + ( K[i_1,:].T*(u[i_1] - old_u)*v) | ||
|
||
else: | ||
old_v = v[i_2] | ||
v[i_2] = b[i_2]/(K[:,i_2].T@u) | ||
G[:,i_2] = u*K[:,i_2]*v[i_2] | ||
#aviol = (G@one_m - a) | ||
#aviol_2 = (G.T@one_n - b) | ||
viol = viol + ( -old_v + v[i_2])*K[:,i_2]*u | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. viol += ... |
||
viol_2[i_2] = v[i_2]*K[:,i_2]@u - b[i_2] | ||
|
||
#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2))) | ||
|
||
if log: | ||
log['u'] = u | ||
log['v'] = v | ||
|
||
if log: | ||
return G,log | ||
else: | ||
return G | ||
|
||
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, | ||
warmstart=None, verbose=False, print_period=20, log=False, **kwargs): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,12 +71,14 @@ def test_sinkhorn_variants(): | |
Ges = ot.sinkhorn( | ||
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10) | ||
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10) | ||
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10) | ||
|
||
# check values | ||
np.testing.assert_allclose(G0, Gs, atol=1e-05) | ||
np.testing.assert_allclose(G0, Ges, atol=1e-05) | ||
np.testing.assert_allclose(G0, Gerr) | ||
|
||
np.testing.assert_allclose(G0, G_green, atol = 1e-32) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a pep8 checker would tell you but you should not put spaces around = in function signatures. It's to visually distinguish what is a function parameter from a variable assignment. |
||
print(G0,G_green) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and you should always put a space after a , |
||
|
||
def test_bary(): | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ot.greenkhorn