@@ -364,7 +364,7 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
364
364
raise ValueError ("Unknown method '%s'." % method )
365
365
366
366
367
- def sinkhorn_knopp (a , b , M , reg , numItermax = 1000 , stopThr = 1e-9 ,
367
+ def sinkhorn_knopp (a , b , M , reg , numItermax = 1000 , stopThr = 1e-9 , warmstart = None ,
368
368
verbose = False , log = False , warn = True ,
369
369
** kwargs ):
370
370
r"""
@@ -409,6 +409,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
409
409
Max number of iterations
410
410
stopThr : float, optional
411
411
Stop threshold on error (>0)
412
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
413
+ Initialization of dual vectors. If provided, the dual vectors must be in logarithm form,
414
+ i.e. warmstart = (log_u, log_v), but not (u, v).
412
415
verbose : bool, optional
413
416
Print information along iterations
414
417
log : bool, optional
@@ -474,12 +477,15 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
474
477
475
478
# we assume that no distances are null except those of the diagonal of
476
479
# distances
477
- if n_hists :
478
- u = nx .ones ((dim_a , n_hists ), type_as = M ) / dim_a
479
- v = nx .ones ((dim_b , n_hists ), type_as = M ) / dim_b
480
+ if warmstart is None :
481
+ if n_hists :
482
+ u = nx .ones ((dim_a , n_hists ), type_as = M ) / dim_a
483
+ v = nx .ones ((dim_b , n_hists ), type_as = M ) / dim_b
484
+ else :
485
+ u = nx .ones (dim_a , type_as = M ) / dim_a
486
+ v = nx .ones (dim_b , type_as = M ) / dim_b
480
487
else :
481
- u = nx .ones (dim_a , type_as = M ) / dim_a
482
- v = nx .ones (dim_b , type_as = M ) / dim_b
488
+ u , v = nx .exp (warmstart [0 ]), nx .exp (warmstart [1 ])
483
489
484
490
K = nx .exp (M / (- reg ))
485
491
@@ -546,7 +552,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
546
552
return u .reshape ((- 1 , 1 )) * K * v .reshape ((1 , - 1 ))
547
553
548
554
549
- def sinkhorn_log (a , b , M , reg , numItermax = 1000 , stopThr = 1e-9 , verbose = False ,
555
+ def sinkhorn_log (a , b , M , reg , numItermax = 1000 , stopThr = 1e-9 , warmstart = None , verbose = False ,
550
556
log = False , warn = True , ** kwargs ):
551
557
r"""
552
558
Solve the entropic regularization optimal transport problem in log space
@@ -590,6 +596,9 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
590
596
Max number of iterations
591
597
stopThr : float, optional
592
598
Stop threshold on error (>0)
599
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
600
+ Initialization of dual vectors. If provided, the dual vectors must be in logarithm form,
601
+ i.e. warmstart = (log_u, log_v), but not (u, v).
593
602
verbose : bool, optional
594
603
Print information along iterations
595
604
log : bool, optional
@@ -656,14 +665,18 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
656
665
else :
657
666
n_hists = 0
658
667
668
+ # in case of multiple historgrams
669
+ if n_hists > 1 and warmstart is None :
670
+ warmstart = [None ] * n_hists
671
+
659
672
if n_hists : # we do not want to use tensors sor we do a loop
660
673
661
674
lst_loss = []
662
675
lst_u = []
663
676
lst_v = []
664
677
665
678
for k in range (n_hists ):
666
- res = sinkhorn_log (a , b [:, k ], M , reg , numItermax = numItermax ,
679
+ res = sinkhorn_log (a , b [:, k ], M , reg , numItermax = numItermax , warmstart = warmstart [ k ],
667
680
stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
668
681
669
682
if log :
@@ -691,9 +704,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
691
704
692
705
# we assume that no distances are null except those of the diagonal of
693
706
# distances
694
-
695
- u = nx .zeros (dim_a , type_as = M )
696
- v = nx .zeros (dim_b , type_as = M )
707
+ if warmstart is None :
708
+ u = nx .zeros (dim_a , type_as = M )
709
+ v = nx .zeros (dim_b , type_as = M )
710
+ else :
711
+ u , v = warmstart
697
712
698
713
def get_logT (u , v ):
699
714
if n_hists :
0 commit comments