Skip to content

Commit eabeabe

Browse files
committed
Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman`
1 parent 7ba5f03 commit eabeabe

File tree

3 files changed

+154
-42
lines changed

3 files changed

+154
-42
lines changed

RELEASES.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
99
- Added Free Support Sinkhorn Barycenter + example (PR #387)
1010
- New API for OT solver using function `ot.solve` (PR #388)
11-
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
12-
11+
- Backend version of `ot.partial` and `ot.smooth` (PR #388)
12+
- Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman` (PR #)
1313

1414
#### Closed issues
1515

ot/bregman.py

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .backend import get_backend
2525

2626

27-
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
27+
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None,
2828
stopThr=1e-9, verbose=False, log=False, warn=True,
2929
**kwargs):
3030
r"""
@@ -93,6 +93,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
9393
those function for specific parameters
9494
numItermax : int, optional
9595
Max number of iterations
96+
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
97+
Initialization of dual vectors. If provided, the dual vectors must already be in logarithmic form,
98+
i.e. warmstart = (log_u, log_v), but not (u, v).
9699
stopThr : float, optional
97100
Stop threshold on error (>0)
98101
verbose : bool, optional
@@ -154,35 +157,35 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
154157
"""
155158

156159
if method.lower() == 'sinkhorn':
157-
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
160+
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
158161
stopThr=stopThr, verbose=verbose, log=log,
159162
warn=warn,
160163
**kwargs)
161164
elif method.lower() == 'sinkhorn_log':
162-
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
165+
return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
163166
stopThr=stopThr, verbose=verbose, log=log,
164167
warn=warn,
165168
**kwargs)
166169
elif method.lower() == 'greenkhorn':
167-
return greenkhorn(a, b, M, reg, numItermax=numItermax,
170+
return greenkhorn(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
168171
stopThr=stopThr, verbose=verbose, log=log,
169172
warn=warn)
170173
elif method.lower() == 'sinkhorn_stabilized':
171-
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
174+
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
172175
stopThr=stopThr, verbose=verbose,
173176
log=log, warn=warn,
174177
**kwargs)
175178
elif method.lower() == 'sinkhorn_epsilon_scaling':
176179
return sinkhorn_epsilon_scaling(a, b, M, reg,
177-
numItermax=numItermax,
180+
numItermax=numItermax, warmstart=warmstart,
178181
stopThr=stopThr, verbose=verbose,
179182
log=log, warn=warn,
180183
**kwargs)
181184
else:
182185
raise ValueError("Unknown method '%s'." % method)
183186

184187

185-
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
188+
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, warmstart=None,
186189
stopThr=1e-9, verbose=False, log=False, warn=False, **kwargs):
187190
r"""
188191
Solve the entropic regularization optimal transport problem and return the loss
@@ -252,6 +255,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
252255
'sinkhorn_stabilized', see those function for specific parameters
253256
numItermax : int, optional
254257
Max number of iterations
258+
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
259+
Initialization of dual vectors. If provided, the dual vectors must already be in logarithmic form,
260+
i.e. warmstart = (log_u, log_v), but not (u, v).
255261
stopThr : float, optional
256262
Stop threshold on error (>0)
257263
verbose : bool, optional
@@ -322,17 +328,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
322328

323329
if len(b.shape) < 2:
324330
if method.lower() == 'sinkhorn':
325-
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
331+
res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
326332
stopThr=stopThr, verbose=verbose,
327333
log=log, warn=warn,
328334
**kwargs)
329335
elif method.lower() == 'sinkhorn_log':
330-
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax,
336+
res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
331337
stopThr=stopThr, verbose=verbose,
332338
log=log, warn=warn,
333339
**kwargs)
334340
elif method.lower() == 'sinkhorn_stabilized':
335-
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
341+
res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
336342
stopThr=stopThr, verbose=verbose,
337343
log=log, warn=warn,
338344
**kwargs)
@@ -346,25 +352,25 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
346352
else:
347353

348354
if method.lower() == 'sinkhorn':
349-
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
355+
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
350356
stopThr=stopThr, verbose=verbose,
351357
log=log, warn=warn,
352358
**kwargs)
353359
elif method.lower() == 'sinkhorn_log':
354-
return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
360+
return sinkhorn_log(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
355361
stopThr=stopThr, verbose=verbose,
356362
log=log, warn=warn,
357363
**kwargs)
358364
elif method.lower() == 'sinkhorn_stabilized':
359-
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
365+
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, warmstart=warmstart,
360366
stopThr=stopThr, verbose=verbose,
361367
log=log, warn=warn,
362368
**kwargs)
363369
else:
364370
raise ValueError("Unknown method '%s'." % method)
365371

366372

367-
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None,
373+
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9,
368374
verbose=False, log=False, warn=True,
369375
**kwargs):
370376
r"""
@@ -407,11 +413,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None,
407413
Regularization term >0
408414
numItermax : int, optional
409415
Max number of iterations
410-
stopThr : float, optional
411-
Stop threshold on error (>0)
412416
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
413-
Initialization of dual vectors. If provided, the dual vectors must be in logarithm form,
417+
Initialization of dual vectors. If provided, the dual vectors must already be in logarithmic form,
414418
i.e. warmstart = (log_u, log_v), but not (u, v).
419+
stopThr : float, optional
420+
Stop threshold on error (>0)
415421
verbose : bool, optional
416422
Print information along iterations
417423
log : bool, optional
@@ -552,7 +558,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None,
552558
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
553559

554560

555-
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, verbose=False,
561+
def sinkhorn_log(a, b, M, reg, numItermax=1000, warmstart=None, stopThr=1e-9, verbose=False,
556562
log=False, warn=True, **kwargs):
557563
r"""
558564
Solve the entropic regularization optimal transport problem in log space
@@ -594,11 +600,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, ve
594600
Regularization term >0
595601
numItermax : int, optional
596602
Max number of iterations
597-
stopThr : float, optional
598-
Stop threshold on error (>0)
599603
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
600-
Initialization of dual vectors. If provided, the dual vectors must be in logarithm form,
604+
Initialization of dual vectors. If provided, the dual vectors must already be in logarithmic form,
601605
i.e. warmstart = (log_u, log_v), but not (u, v).
606+
stopThr : float, optional
607+
Stop threshold on error (>0)
602608
verbose : bool, optional
603609
Print information along iterations
604610
log : bool, optional
@@ -761,7 +767,7 @@ def get_logT(u, v):
761767
return nx.exp(get_logT(u, v))
762768

763769

764-
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
770+
def greenkhorn(a, b, M, reg, numItermax=10000, warmstart=None, stopThr=1e-9, verbose=False,
765771
log=False, warn=True):
766772
r"""
767773
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -804,6 +810,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
804810
Regularization term >0
805811
numItermax : int, optional
806812
Max number of iterations
813+
warmstart: tuple of two arrays, shapes (dim_a,) and (dim_b,), optional
814+
Initialization of dual vectors. If provided, the dual vectors must already be in logarithmic form,
815+
i.e. warmstart = (log_u, log_v), but not (u, v).
807816
stopThr : float, optional
808817
Stop threshold on error (>0)
809818
log : bool, optional
@@ -868,8 +877,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
868877

869878
K = nx.exp(-M / reg)
870879

871-
u = nx.full((dim_a,), 1. / dim_a, type_as=K)
872-
v = nx.full((dim_b,), 1. / dim_b, type_as=K)
880+
if warmstart is None:
881+
u = nx.full((dim_a,), 1. / dim_a, type_as=K)
882+
v = nx.full((dim_b,), 1. / dim_b, type_as=K)
883+
else:
884+
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
873885
G = u[:, None] * K * v[None, :]
874886

875887
viol = nx.sum(G, axis=1) - a
@@ -2872,7 +2884,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
28722884

28732885

28742886
def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
2875-
numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
2887+
numIterMax=10000, warmstart=None, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
28762888
log=False, warn=True, **kwargs):
28772889
r'''
28782890
Solve the entropic regularization optimal transport problem and return the
@@ -2911,6 +2923,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
29112923
samples weights in the target domain
29122924
numItermax : int, optional
29132925
Max number of iterations
2926+
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
2927+
Initialization of dual vectors. If provided, the dual vectors must already be in logarithmic form,
2928+
i.e. warmstart = (log_u, log_v), but not (u, v).
29142929
stopThr : float, optional
29152930
Stop threshold on error (>0)
29162931
isLazy: boolean, optional
@@ -2976,7 +2991,10 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
29762991
dict_log = {"err": []}
29772992

29782993
log_a, log_b = nx.log(a), nx.log(b)
2979-
f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
2994+
if warmstart is None:
2995+
f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
2996+
else:
2997+
f, g = warmstart
29802998

29812999
if isinstance(batchSize, int):
29823000
bs, bt = batchSize, batchSize
@@ -3048,17 +3066,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
30483066
else:
30493067
M = dist(X_s, X_t, metric=metric)
30503068
if log:
3051-
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
3069+
pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr,
30523070
verbose=verbose, log=True, **kwargs)
30533071
return pi, log
30543072
else:
3055-
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
3073+
pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart, stopThr=stopThr,
30563074
verbose=verbose, log=False, **kwargs)
30573075
return pi
30583076

30593077

30603078
def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
3061-
numIterMax=10000, stopThr=1e-9, isLazy=False,
3079+
numIterMax=10000, warmstart=None, stopThr=1e-9, isLazy=False,
30623080
batchSize=100, verbose=False, log=False, warn=True, **kwargs):
30633081
r'''
30643082
Solve the entropic regularization optimal transport problem from empirical
@@ -3101,6 +3119,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
31013119
samples weights in the target domain
31023120
numItermax : int, optional
31033121
Max number of iterations
3122+
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
3123+
Initialization of dual vectors. If provided, the dual vectors must already be in logarithmic form,
3124+
i.e. warmstart = (log_u, log_v), but not (u, v).
31043125
stopThr : float, optional
31053126
Stop threshold on error (>0)
31063127
isLazy: boolean, optional
@@ -3167,15 +3188,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
31673188
if isLazy:
31683189
if log:
31693190
f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
3170-
numIterMax=numIterMax,
3191+
numIterMax=numIterMax,
3192+
warmstart=warmstart,
31713193
stopThr=stopThr,
31723194
isLazy=isLazy,
31733195
batchSize=batchSize,
31743196
verbose=verbose, log=log,
31753197
warn=warn)
31763198
else:
31773199
f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
3178-
numIterMax=numIterMax, stopThr=stopThr,
3200+
numIterMax=numIterMax,
3201+
warmstart=warmstart,
3202+
stopThr=stopThr,
31793203
isLazy=isLazy, batchSize=batchSize,
31803204
verbose=verbose, log=log,
31813205
warn=warn)
@@ -3203,19 +3227,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
32033227
M = dist(X_s, X_t, metric=metric)
32043228

32053229
if log:
3206-
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
3230+
sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart,
32073231
stopThr=stopThr, verbose=verbose, log=log,
32083232
warn=warn, **kwargs)
32093233
return sinkhorn_loss, log
32103234
else:
3211-
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
3235+
sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, warmstart=warmstart,
32123236
stopThr=stopThr, verbose=verbose, log=log,
32133237
warn=warn, **kwargs)
32143238
return sinkhorn_loss
32153239

32163240

32173241
def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
3218-
numIterMax=10000, stopThr=1e-9,
3242+
numIterMax=10000, warmstart=None, stopThr=1e-9,
32193243
verbose=False, log=False, warn=True,
32203244
**kwargs):
32213245
r'''
@@ -3286,6 +3310,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
32863310
samples weights in the target domain
32873311
numItermax : int, optional
32883312
Max number of iterations
3313+
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
3314+
Initialization of dual vectors. If provided, the dual vectors must already be in logarithmic form,
3315+
i.e. warmstart = (log_u, log_v), but not (u, v).
32893316
stopThr : float, optional
32903317
Stop threshold on error (>0)
32913318
verbose : bool, optional
@@ -3323,20 +3350,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33233350
X_s, X_t = list_to_array(X_s, X_t)
33243351

33253352
nx = get_backend(X_s, X_t)
3353+
if warmstart is None:
3354+
warmstart_a, warmstart_b = None, None
3355+
else:
3356+
u, v = warmstart
3357+
warmstart_a = (u, u)
3358+
warmstart_b = (v, v)
33263359

33273360
if log:
33283361
sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
3329-
numIterMax=numIterMax,
3362+
numIterMax=numIterMax, warmstart=warmstart,
33303363
stopThr=stopThr, verbose=verbose,
33313364
log=log, warn=warn, **kwargs)
33323365

33333366
sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
3334-
numIterMax=numIterMax,
3367+
numIterMax=numIterMax, warmstart=warmstart_a,
33353368
stopThr=stopThr, verbose=verbose,
33363369
log=log, warn=warn, **kwargs)
33373370

33383371
sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
3339-
numIterMax=numIterMax,
3372+
numIterMax=numIterMax, warmstart=warmstart_b,
33403373
stopThr=stopThr, verbose=verbose,
33413374
log=log, warn=warn, **kwargs)
33423375

@@ -3354,17 +3387,20 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
33543387

33553388
else:
33563389
sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric,
3357-
numIterMax=numIterMax, stopThr=stopThr,
3390+
numIterMax=numIterMax, warmstart=warmstart,
3391+
stopThr=stopThr,
33583392
verbose=verbose, log=log,
33593393
warn=warn, **kwargs)
33603394

33613395
sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric,
3362-
numIterMax=numIterMax, stopThr=stopThr,
3396+
numIterMax=numIterMax, warmstart=warmstart_a,
3397+
stopThr=stopThr,
33633398
verbose=verbose, log=log,
33643399
warn=warn, **kwargs)
33653400

33663401
sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric,
3367-
numIterMax=numIterMax, stopThr=stopThr,
3402+
numIterMax=numIterMax, warmstart=warmstart_b,
3403+
stopThr=stopThr,
33683404
verbose=verbose, log=log,
33693405
warn=warn, **kwargs)
33703406

0 commit comments

Comments
 (0)