24
24
from .backend import get_backend
25
25
26
26
27
- def sinkhorn (a , b , M , reg , method = 'sinkhorn' , numItermax = 1000 ,
27
+ def sinkhorn (a , b , M , reg , method = 'sinkhorn' , numItermax = 1000 , warmstart = None ,
28
28
stopThr = 1e-9 , verbose = False , log = False , warn = True ,
29
29
** kwargs ):
30
30
r"""
@@ -93,6 +93,9 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
93
93
those function for specific parameters
94
94
numItermax : int, optional
95
95
Max number of iterations
96
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
97
+ Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm,
98
+ i.e. warmstart = (log_u, log_v), but not (u, v).
96
99
stopThr : float, optional
97
100
Stop threshold on error (>0)
98
101
verbose : bool, optional
@@ -154,35 +157,35 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
154
157
"""
155
158
156
159
if method .lower () == 'sinkhorn' :
157
- return sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
160
+ return sinkhorn_knopp (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
158
161
stopThr = stopThr , verbose = verbose , log = log ,
159
162
warn = warn ,
160
163
** kwargs )
161
164
elif method .lower () == 'sinkhorn_log' :
162
- return sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
165
+ return sinkhorn_log (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
163
166
stopThr = stopThr , verbose = verbose , log = log ,
164
167
warn = warn ,
165
168
** kwargs )
166
169
elif method .lower () == 'greenkhorn' :
167
- return greenkhorn (a , b , M , reg , numItermax = numItermax ,
170
+ return greenkhorn (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
168
171
stopThr = stopThr , verbose = verbose , log = log ,
169
172
warn = warn )
170
173
elif method .lower () == 'sinkhorn_stabilized' :
171
- return sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
174
+ return sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
172
175
stopThr = stopThr , verbose = verbose ,
173
176
log = log , warn = warn ,
174
177
** kwargs )
175
178
elif method .lower () == 'sinkhorn_epsilon_scaling' :
176
179
return sinkhorn_epsilon_scaling (a , b , M , reg ,
177
- numItermax = numItermax ,
180
+ numItermax = numItermax , warmstart = warmstart ,
178
181
stopThr = stopThr , verbose = verbose ,
179
182
log = log , warn = warn ,
180
183
** kwargs )
181
184
else :
182
185
raise ValueError ("Unknown method '%s'." % method )
183
186
184
187
185
- def sinkhorn2 (a , b , M , reg , method = 'sinkhorn' , numItermax = 1000 ,
188
+ def sinkhorn2 (a , b , M , reg , method = 'sinkhorn' , numItermax = 1000 , warmstart = None ,
186
189
stopThr = 1e-9 , verbose = False , log = False , warn = False , ** kwargs ):
187
190
r"""
188
191
Solve the entropic regularization optimal transport problem and return the loss
@@ -252,6 +255,9 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
252
255
'sinkhorn_stabilized', see those function for specific parameters
253
256
numItermax : int, optional
254
257
Max number of iterations
258
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
259
+ Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm,
260
+ i.e. warmstart = (log_u, log_v), but not (u, v).
255
261
stopThr : float, optional
256
262
Stop threshold on error (>0)
257
263
verbose : bool, optional
@@ -322,17 +328,17 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
322
328
323
329
if len (b .shape ) < 2 :
324
330
if method .lower () == 'sinkhorn' :
325
- res = sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
331
+ res = sinkhorn_knopp (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
326
332
stopThr = stopThr , verbose = verbose ,
327
333
log = log , warn = warn ,
328
334
** kwargs )
329
335
elif method .lower () == 'sinkhorn_log' :
330
- res = sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
336
+ res = sinkhorn_log (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
331
337
stopThr = stopThr , verbose = verbose ,
332
338
log = log , warn = warn ,
333
339
** kwargs )
334
340
elif method .lower () == 'sinkhorn_stabilized' :
335
- res = sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
341
+ res = sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
336
342
stopThr = stopThr , verbose = verbose ,
337
343
log = log , warn = warn ,
338
344
** kwargs )
@@ -346,25 +352,25 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
346
352
else :
347
353
348
354
if method .lower () == 'sinkhorn' :
349
- return sinkhorn_knopp (a , b , M , reg , numItermax = numItermax ,
355
+ return sinkhorn_knopp (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
350
356
stopThr = stopThr , verbose = verbose ,
351
357
log = log , warn = warn ,
352
358
** kwargs )
353
359
elif method .lower () == 'sinkhorn_log' :
354
- return sinkhorn_log (a , b , M , reg , numItermax = numItermax ,
360
+ return sinkhorn_log (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
355
361
stopThr = stopThr , verbose = verbose ,
356
362
log = log , warn = warn ,
357
363
** kwargs )
358
364
elif method .lower () == 'sinkhorn_stabilized' :
359
- return sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax ,
365
+ return sinkhorn_stabilized (a , b , M , reg , numItermax = numItermax , warmstart = warmstart ,
360
366
stopThr = stopThr , verbose = verbose ,
361
367
log = log , warn = warn ,
362
368
** kwargs )
363
369
else :
364
370
raise ValueError ("Unknown method '%s'." % method )
365
371
366
372
367
- def sinkhorn_knopp (a , b , M , reg , numItermax = 1000 , stopThr = 1e-9 , warmstart = None ,
373
+ def sinkhorn_knopp (a , b , M , reg , numItermax = 1000 , warmstart = None , stopThr = 1e-9 ,
368
374
verbose = False , log = False , warn = True ,
369
375
** kwargs ):
370
376
r"""
@@ -407,11 +413,11 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None,
407
413
Regularization term >0
408
414
numItermax : int, optional
409
415
Max number of iterations
410
- stopThr : float, optional
411
- Stop threshold on error (>0)
412
416
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
413
- Initialization of dual vectors. If provided, the dual vectors must be in logarithm form ,
417
+ Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm ,
414
418
i.e. warmstart = (log_u, log_v), but not (u, v).
419
+ stopThr : float, optional
420
+ Stop threshold on error (>0)
415
421
verbose : bool, optional
416
422
Print information along iterations
417
423
log : bool, optional
@@ -552,7 +558,7 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None,
552
558
return u .reshape ((- 1 , 1 )) * K * v .reshape ((1 , - 1 ))
553
559
554
560
555
- def sinkhorn_log (a , b , M , reg , numItermax = 1000 , stopThr = 1e-9 , warmstart = None , verbose = False ,
561
+ def sinkhorn_log (a , b , M , reg , numItermax = 1000 , warmstart = None , stopThr = 1e-9 , verbose = False ,
556
562
log = False , warn = True , ** kwargs ):
557
563
r"""
558
564
Solve the entropic regularization optimal transport problem in log space
@@ -594,11 +600,11 @@ def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, warmstart=None, ve
594
600
Regularization term >0
595
601
numItermax : int, optional
596
602
Max number of iterations
597
- stopThr : float, optional
598
- Stop threshold on error (>0)
599
603
warmstart: tuple of arrays, shape (dim_a, dim_b), optional
600
- Initialization of dual vectors. If provided, the dual vectors must be in logarithm form ,
604
+ Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm ,
601
605
i.e. warmstart = (log_u, log_v), but not (u, v).
606
+ stopThr : float, optional
607
+ Stop threshold on error (>0)
602
608
verbose : bool, optional
603
609
Print information along iterations
604
610
log : bool, optional
@@ -761,7 +767,7 @@ def get_logT(u, v):
761
767
return nx .exp (get_logT (u , v ))
762
768
763
769
764
- def greenkhorn (a , b , M , reg , numItermax = 10000 , stopThr = 1e-9 , verbose = False ,
770
+ def greenkhorn (a , b , M , reg , numItermax = 10000 , warmstart = None , stopThr = 1e-9 , verbose = False ,
765
771
log = False , warn = True ):
766
772
r"""
767
773
Solve the entropic regularization optimal transport problem and return the OT matrix
@@ -804,6 +810,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
804
810
Regularization term >0
805
811
numItermax : int, optional
806
812
Max number of iterations
813
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
814
+ Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm,
815
+ i.e. warmstart = (log_u, log_v), but not (u, v).
807
816
stopThr : float, optional
808
817
Stop threshold on error (>0)
809
818
log : bool, optional
@@ -868,8 +877,11 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
868
877
869
878
K = nx .exp (- M / reg )
870
879
871
- u = nx .full ((dim_a ,), 1. / dim_a , type_as = K )
872
- v = nx .full ((dim_b ,), 1. / dim_b , type_as = K )
880
+ if warmstart is None :
881
+ u = nx .full ((dim_a ,), 1. / dim_a , type_as = K )
882
+ v = nx .full ((dim_b ,), 1. / dim_b , type_as = K )
883
+ else :
884
+ u , v = nx .exp (warmstart [0 ]), nx .exp (warmstart [1 ])
873
885
G = u [:, None ] * K * v [None , :]
874
886
875
887
viol = nx .sum (G , axis = 1 ) - a
@@ -2872,7 +2884,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
2872
2884
2873
2885
2874
2886
def empirical_sinkhorn (X_s , X_t , reg , a = None , b = None , metric = 'sqeuclidean' ,
2875
- numIterMax = 10000 , stopThr = 1e-9 , isLazy = False , batchSize = 100 , verbose = False ,
2887
+ numIterMax = 10000 , warmstart = None , stopThr = 1e-9 , isLazy = False , batchSize = 100 , verbose = False ,
2876
2888
log = False , warn = True , ** kwargs ):
2877
2889
r'''
2878
2890
Solve the entropic regularization optimal transport problem and return the
@@ -2911,6 +2923,9 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
2911
2923
samples weights in the target domain
2912
2924
numItermax : int, optional
2913
2925
Max number of iterations
2926
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
2927
+ Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm,
2928
+ i.e. warmstart = (log_u, log_v), but not (u, v).
2914
2929
stopThr : float, optional
2915
2930
Stop threshold on error (>0)
2916
2931
isLazy: boolean, optional
@@ -2976,7 +2991,10 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
2976
2991
dict_log = {"err" : []}
2977
2992
2978
2993
log_a , log_b = nx .log (a ), nx .log (b )
2979
- f , g = nx .zeros ((ns ,), type_as = a ), nx .zeros ((nt ,), type_as = a )
2994
+ if warmstart is None :
2995
+ f , g = nx .zeros ((ns ,), type_as = a ), nx .zeros ((nt ,), type_as = a )
2996
+ else :
2997
+ f , g = warmstart
2980
2998
2981
2999
if isinstance (batchSize , int ):
2982
3000
bs , bt = batchSize , batchSize
@@ -3048,17 +3066,17 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
3048
3066
else :
3049
3067
M = dist (X_s , X_t , metric = metric )
3050
3068
if log :
3051
- pi , log = sinkhorn (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr ,
3069
+ pi , log = sinkhorn (a , b , M , reg , numItermax = numIterMax , warmstart = warmstart , stopThr = stopThr ,
3052
3070
verbose = verbose , log = True , ** kwargs )
3053
3071
return pi , log
3054
3072
else :
3055
- pi = sinkhorn (a , b , M , reg , numItermax = numIterMax , stopThr = stopThr ,
3073
+ pi = sinkhorn (a , b , M , reg , numItermax = numIterMax , warmstart = warmstart , stopThr = stopThr ,
3056
3074
verbose = verbose , log = False , ** kwargs )
3057
3075
return pi
3058
3076
3059
3077
3060
3078
def empirical_sinkhorn2 (X_s , X_t , reg , a = None , b = None , metric = 'sqeuclidean' ,
3061
- numIterMax = 10000 , stopThr = 1e-9 , isLazy = False ,
3079
+ numIterMax = 10000 , warmstart = None , stopThr = 1e-9 , isLazy = False ,
3062
3080
batchSize = 100 , verbose = False , log = False , warn = True , ** kwargs ):
3063
3081
r'''
3064
3082
Solve the entropic regularization optimal transport problem from empirical
@@ -3101,6 +3119,9 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
3101
3119
samples weights in the target domain
3102
3120
numItermax : int, optional
3103
3121
Max number of iterations
3122
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
3123
+ Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm,
3124
+ i.e. warmstart = (log_u, log_v), but not (u, v).
3104
3125
stopThr : float, optional
3105
3126
Stop threshold on error (>0)
3106
3127
isLazy: boolean, optional
@@ -3167,15 +3188,18 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
3167
3188
if isLazy :
3168
3189
if log :
3169
3190
f , g , dict_log = empirical_sinkhorn (X_s , X_t , reg , a , b , metric ,
3170
- numIterMax = numIterMax ,
3191
+ numIterMax = numIterMax ,
3192
+ warmstart = warmstart ,
3171
3193
stopThr = stopThr ,
3172
3194
isLazy = isLazy ,
3173
3195
batchSize = batchSize ,
3174
3196
verbose = verbose , log = log ,
3175
3197
warn = warn )
3176
3198
else :
3177
3199
f , g = empirical_sinkhorn (X_s , X_t , reg , a , b , metric ,
3178
- numIterMax = numIterMax , stopThr = stopThr ,
3200
+ numIterMax = numIterMax ,
3201
+ warmstart = warmstart ,
3202
+ stopThr = stopThr ,
3179
3203
isLazy = isLazy , batchSize = batchSize ,
3180
3204
verbose = verbose , log = log ,
3181
3205
warn = warn )
@@ -3203,19 +3227,19 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
3203
3227
M = dist (X_s , X_t , metric = metric )
3204
3228
3205
3229
if log :
3206
- sinkhorn_loss , log = sinkhorn2 (a , b , M , reg , numItermax = numIterMax ,
3230
+ sinkhorn_loss , log = sinkhorn2 (a , b , M , reg , numItermax = numIterMax , warmstart = warmstart ,
3207
3231
stopThr = stopThr , verbose = verbose , log = log ,
3208
3232
warn = warn , ** kwargs )
3209
3233
return sinkhorn_loss , log
3210
3234
else :
3211
- sinkhorn_loss = sinkhorn2 (a , b , M , reg , numItermax = numIterMax ,
3235
+ sinkhorn_loss = sinkhorn2 (a , b , M , reg , numItermax = numIterMax , warmstart = warmstart ,
3212
3236
stopThr = stopThr , verbose = verbose , log = log ,
3213
3237
warn = warn , ** kwargs )
3214
3238
return sinkhorn_loss
3215
3239
3216
3240
3217
3241
def empirical_sinkhorn_divergence (X_s , X_t , reg , a = None , b = None , metric = 'sqeuclidean' ,
3218
- numIterMax = 10000 , stopThr = 1e-9 ,
3242
+ numIterMax = 10000 , warmstart = None , stopThr = 1e-9 ,
3219
3243
verbose = False , log = False , warn = True ,
3220
3244
** kwargs ):
3221
3245
r'''
@@ -3286,6 +3310,9 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3286
3310
samples weights in the target domain
3287
3311
numItermax : int, optional
3288
3312
Max number of iterations
3313
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional
3314
+ Initialization of dual vectors. If provided, the dual vectors must be already taken the logarithm,
3315
+ i.e. warmstart = (log_u, log_v), but not (u, v).
3289
3316
stopThr : float, optional
3290
3317
Stop threshold on error (>0)
3291
3318
verbose : bool, optional
@@ -3323,20 +3350,26 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3323
3350
X_s , X_t = list_to_array (X_s , X_t )
3324
3351
3325
3352
nx = get_backend (X_s , X_t )
3353
+ if warmstart is None :
3354
+ warmstart_a , warmstart_b = None , None
3355
+ else :
3356
+ u , v = warmstart
3357
+ warmstart_a = (u , u )
3358
+ warmstart_b = (v , v )
3326
3359
3327
3360
if log :
3328
3361
sinkhorn_loss_ab , log_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
3329
- numIterMax = numIterMax ,
3362
+ numIterMax = numIterMax , warmstart = warmstart ,
3330
3363
stopThr = stopThr , verbose = verbose ,
3331
3364
log = log , warn = warn , ** kwargs )
3332
3365
3333
3366
sinkhorn_loss_a , log_a = empirical_sinkhorn2 (X_s , X_s , reg , a , a , metric = metric ,
3334
- numIterMax = numIterMax ,
3367
+ numIterMax = numIterMax , warmstart = warmstart_a ,
3335
3368
stopThr = stopThr , verbose = verbose ,
3336
3369
log = log , warn = warn , ** kwargs )
3337
3370
3338
3371
sinkhorn_loss_b , log_b = empirical_sinkhorn2 (X_t , X_t , reg , b , b , metric = metric ,
3339
- numIterMax = numIterMax ,
3372
+ numIterMax = numIterMax , warmstart = warmstart_b ,
3340
3373
stopThr = stopThr , verbose = verbose ,
3341
3374
log = log , warn = warn , ** kwargs )
3342
3375
@@ -3354,17 +3387,20 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
3354
3387
3355
3388
else :
3356
3389
sinkhorn_loss_ab = empirical_sinkhorn2 (X_s , X_t , reg , a , b , metric = metric ,
3357
- numIterMax = numIterMax , stopThr = stopThr ,
3390
+ numIterMax = numIterMax , warmstart = warmstart ,
3391
+ stopThr = stopThr ,
3358
3392
verbose = verbose , log = log ,
3359
3393
warn = warn , ** kwargs )
3360
3394
3361
3395
sinkhorn_loss_a = empirical_sinkhorn2 (X_s , X_s , reg , a , a , metric = metric ,
3362
- numIterMax = numIterMax , stopThr = stopThr ,
3396
+ numIterMax = numIterMax , warmstart = warmstart_a ,
3397
+ stopThr = stopThr ,
3363
3398
verbose = verbose , log = log ,
3364
3399
warn = warn , ** kwargs )
3365
3400
3366
3401
sinkhorn_loss_b = empirical_sinkhorn2 (X_t , X_t , reg , b , b , metric = metric ,
3367
- numIterMax = numIterMax , stopThr = stopThr ,
3402
+ numIterMax = numIterMax , warmstart = warmstart_b ,
3403
+ stopThr = stopThr ,
3368
3404
verbose = verbose , log = log ,
3369
3405
warn = warn , ** kwargs )
3370
3406
0 commit comments