Skip to content
151 changes: 150 additions & 1 deletion ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
'sinkhorn_epsilon_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
Expand Down Expand Up @@ -103,6 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
def sink():
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
if method.lower() == 'greenkhorn':
def sink():
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
def sink():
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
Expand Down Expand Up @@ -197,13 +201,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,

.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

.. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 30, 2017



See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
ot.bregman.greenkhorn : Greenkhorn [21]
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling [9][10]

Expand Down Expand Up @@ -410,6 +417,148 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))



def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
               log=False):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The algorithm used is based on the paper

    Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
        by Jason Altschuler, Jonathan Weed, Philippe Rigollet
        appeared at NIPS 2017

    which is a greedy version of the Sinkhorn-Knopp algorithm [2]: at each
    iteration only the single row or column whose marginal is violated the
    most is rescaled.

    The function solves the following optimization problem:

    .. math::
        \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

        s.t. \gamma 1 = a

             \gamma^T 1= b

             \gamma\geq 0
    where :

    - M is the (ns,nt) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - a and b are source and target weights (sum to 1)


    Parameters
    ----------
    a : array-like (ns,)
        samples weights in the source domain
    b : array-like (nt,)
        samples weights in the target domain (1-D only, a matrix of several
        targets is not handled by this solver)
    M : array-like (ns,nt)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the largest absolute marginal violation (>0)
    verbose : bool, optional
        Accepted for signature compatibility with the other solvers; not
        used by this function
    log : bool, optional
        record log if True


    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters. It holds the
        scaling vectors under the keys 'u' and 'v', with
        gamma = diag(u) * exp(-M / reg) * diag(v)

    Examples
    --------

    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.bregman.greenkhorn(a,b,M,1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])


    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
    .. [21] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 30, 2017


    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    import warnings

    # Accept plain lists (as in the docstring example) and integer costs:
    # everything below relies on float ndarrays.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    n = a.shape[0]
    m = b.shape[0]

    # Next 3 lines equivalent to K = np.exp(-M / reg), but faster to compute
    K = np.empty_like(M)
    np.divide(M, -reg, out=K)
    np.exp(K, out=K)

    u = np.full(n, 1. / n)
    v = np.full(m, 1. / m)
    # Broadcasting computes diag(u) K diag(v) without building the diagonals.
    G = u[:, np.newaxis] * K * v[np.newaxis, :]

    # Signed violation of the row (source) and column (target) marginals.
    viol = G.sum(axis=1) - a
    viol_2 = G.sum(axis=0) - b

    for _ in range(numItermax):
        i_1 = np.argmax(np.abs(viol))
        i_2 = np.argmax(np.abs(viol_2))
        m_viol_1 = np.abs(viol[i_1])
        m_viol_2 = np.abs(viol_2[i_2])
        # Measured before this iteration's update, as in the stopping test
        # below: the update is still applied once the threshold is reached.
        stopThr_val = np.maximum(m_viol_1, m_viol_2)

        if m_viol_1 > m_viol_2:
            # Rescale the worst row so that its sum matches a[i_1] ...
            old_u = u[i_1]
            u[i_1] = a[i_1] / K[i_1, :].dot(v)
            G[i_1, :] = u[i_1] * K[i_1, :] * v

            # ... and update both violation vectors incrementally instead of
            # re-summing the whole matrix.
            viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
            viol_2 += K[i_1, :] * (u[i_1] - old_u) * v
        else:
            # Same thing for the worst column and b[i_2].
            old_v = v[i_2]
            v[i_2] = b[i_2] / K[:, i_2].dot(u)
            G[:, i_2] = u * K[:, i_2] * v[i_2]

            viol += (v[i_2] - old_v) * K[:, i_2] * u
            viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]

        if stopThr_val <= stopThr:
            break
    else:
        warnings.warn('greenkhorn did not converge in %d iterations'
                      % numItermax)

    if log:
        # `log` is only a flag: the dictionary has to be created here.
        return G, {'u': u, 'v': v}
    return G

def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""
Expand Down
4 changes: 3 additions & 1 deletion test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,14 @@ def test_sinkhorn_variants():
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)

# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, Gerr)

np.testing.assert_allclose(G0, G_green, atol = 1e-32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a pep8 checker would tell you but you should not put spaces around = in function signatures. It's to visually distinguish what is a function parameter from a variable assignment.

print(G0,G_green)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and you should always put a space after a ,


def test_bary():

Expand Down