Skip to content

Commit 7ba5f03

Browse files
committed
Allow warmstart in sinkhorn and sinkhorn_log
1 parent 97feeb3 commit 7ba5f03

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

ot/bregman.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
364364
raise ValueError("Unknown method '%s'." % method)
365365

366366

367-
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
367+
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None,
368368
verbose=False, log=False, warn=True,
369369
**kwargs):
370370
r"""
@@ -409,6 +409,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
409409
Max number of iterations
410410
stopThr : float, optional
411411
Stop threshold on error (>0)
412+
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
413+
Initialization of dual vectors. If provided, the dual vectors must be in logarithm form,
414+
i.e. warmstart = (log_u, log_v), but not (u, v).
412415
verbose : bool, optional
413416
Print information along iterations
414417
log : bool, optional
@@ -474,12 +477,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
474477

475478
# we assume that no distances are null except those of the diagonal of
476479
# distances
477-
if n_hists:
478-
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
479-
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
480+
if warmstart is None:
481+
if n_hists:
482+
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
483+
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
484+
else:
485+
u = nx.ones(dim_a, type_as=M) / dim_a
486+
v = nx.ones(dim_b, type_as=M) / dim_b
480487
else:
481-
u = nx.ones(dim_a, type_as=M) / dim_a
482-
v = nx.ones(dim_b, type_as=M) / dim_b
488+
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
483489

484490
K = nx.exp(M / (-reg))
485491

@@ -546,7 +552,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
546552
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
547553

548554

549-
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
555+
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, verbose=False,
550556
log=False, warn=True, **kwargs):
551557
r"""
552558
Solve the entropic regularization optimal transport problem in log space
@@ -590,6 +596,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
590596
Max number of iterations
591597
stopThr : float, optional
592598
Stop threshold on error (>0)
599+
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
600+
Initialization of dual vectors. If provided, the dual vectors must be in logarithm form,
601+
i.e. warmstart = (log_u, log_v), but not (u, v).
593602
verbose : bool, optional
594603
Print information along iterations
595604
log : bool, optional
@@ -656,14 +665,18 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
656665
else:
657666
n_hists = 0
658667

668+
# in case of multiple historgrams
669+
if n_hists > 1 and warmstart is None:
670+
warmstart = [None] * n_hists
671+
659672
if n_hists: # we do not want to use tensors sor we do a loop
660673

661674
lst_loss = []
662675
lst_u = []
663676
lst_v = []
664677

665678
for k in range(n_hists):
666-
res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
679+
res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, warmstart=warmstart[k],
667680
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
668681

669682
if log:
@@ -691,9 +704,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
691704

692705
# we assume that no distances are null except those of the diagonal of
693706
# distances
694-
695-
u = nx.zeros(dim_a, type_as=M)
696-
v = nx.zeros(dim_b, type_as=M)
707+
if warmstart is None:
708+
u = nx.zeros(dim_a, type_as=M)
709+
v = nx.zeros(dim_b, type_as=M)
710+
else:
711+
u, v = warmstart
697712

698713
def get_logT(u, v):
699714
if n_hists:

0 commit comments

Comments
 (0)