14
14
#
15
15
# License: MIT License
16
16
17
- import numpy as np
17
+ import math
18
18
import warnings
19
- from .utils import unif , dist
19
+
20
+ import numpy as np
20
21
from scipy .optimize import fmin_l_bfgs_b
22
+ from scipy .special import logsumexp
23
+
24
+ from .utils import unif , dist
25
+
26
+
27
def log_matvec(matrix, u, out):
    r"""Compute a matrix-vector (or matrix-matrix) product in the log domain.

    Evaluates ``out[i, ...] = log(sum_j exp(matrix[i, j] + u[j, ...]))`` and
    stores the result in ``out`` in place.

    The exponentials are stabilised with one shift per row of ``matrix`` and
    one shift per column of ``u``. A single global shift would underflow every
    row/column lying far below the global maximum and wrongly return ``-inf``
    for it.

    Parameters
    ----------
    matrix : array-like, shape (m, n)
        Log-domain matrix.
    u : array-like, shape (n,) or (n, k)
        Log-domain vector, or a stack of ``k`` column vectors.
    out : np.ndarray, shape (m,) or (m, k)
        Pre-allocated output buffer, overwritten with the result. It is
        handed to ``np.dot`` as ``out=`` and must satisfy that function's
        requirements on dtype and memory layout.

    Returns
    -------
    None
        The result is written into ``out``.
    """
    matrix = np.asarray(matrix)
    u = np.asarray(u)

    # Per-row / per-column maxima keep every exponent <= 0 without crushing
    # rows or columns that sit far below the global maximum.
    row_shift = np.max(matrix, axis=1, keepdims=True)
    col_shift = np.max(u, axis=0, keepdims=True)

    # A non-finite shift (e.g. an all -inf column) would give inf - inf = nan;
    # shifting by 0 instead lets exp/log propagate the correct +/-inf.
    row_shift = np.where(np.isfinite(row_shift), row_shift, 0.0)
    col_shift = np.where(np.isfinite(col_shift), col_shift, 0.0)

    np.dot(np.exp(matrix - row_shift), np.exp(u - col_shift), out=out)

    # log(0) = -inf is the mathematically correct value here, not an error.
    with np.errstate(divide='ignore'):
        np.log(out, out=out)

    # Undo the shifts; flatten the row shift when the output is a plain vector.
    out += (row_shift if out.ndim > 1 else row_shift.ravel()) + col_shift
21
33
22
34
23
35
def sinkhorn (a , b , M , reg , method = 'sinkhorn' , numItermax = 1000 ,
@@ -311,61 +323,68 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
311
323
ot.optim.cg : General regularized OT
312
324
313
325
"""
314
-
315
326
a = np .asarray (a , dtype = np .float64 )
316
327
b = np .asarray (b , dtype = np .float64 )
328
+
317
329
M = np .asarray (M , dtype = np .float64 )
318
330
319
331
if len (a ) == 0 :
320
- a = np .ones ((M .shape [0 ],), dtype = np .float64 ) / M .shape [0 ]
332
+ a = np .ones ((M .shape [0 ], 1 ), dtype = np .float64 ) / M .shape [0 ]
321
333
if len (b ) == 0 :
322
- b = np .ones ((M .shape [1 ],), dtype = np .float64 ) / M .shape [1 ]
323
-
324
- # init data
325
- dim_a = len (a )
326
- dim_b = len (b )
334
+ b = np .ones ((M .shape [1 ], 1 ), dtype = np .float64 ) / M .shape [1 ]
327
335
328
336
if len (b .shape ) > 1 :
329
337
n_hists = b .shape [1 ]
330
338
else :
331
339
n_hists = 0
332
340
341
+ if len (a .shape ) == 1 :
342
+ a = a [:, None ]
343
+
344
+ if len (b .shape ) == 1 :
345
+ b = b [:, None ]
346
+
347
+ log_threshold = math .log (stopThr )
348
+ is_logweight = kwargs .get ('is_logweight' , False )
349
+
350
+ if not is_logweight :
351
+ a = np .log (a )
352
+ b = np .log (b )
353
+
354
+ # init data
355
+ dim_a = len (a )
356
+ dim_b = len (b )
357
+
333
358
if log :
334
359
log = {'err' : []}
335
360
336
361
# we assume that no distances are null except those of the diagonal of
337
362
# distances
338
363
if n_hists :
339
- u = np .ones ((dim_a , n_hists )) / dim_a
340
- v = np .ones ((dim_b , n_hists )) / dim_b
364
+ u = np .zeros ((dim_a , n_hists )) - math . log ( dim_a )
365
+ v = np .zeros ((dim_b , n_hists )) - math . log ( dim_b )
341
366
else :
342
- u = np .ones (dim_a ) / dim_a
343
- v = np .ones (dim_b ) / dim_b
344
-
345
- # print(reg)
346
-
347
- # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
348
- K = np .empty (M .shape , dtype = M .dtype )
349
- np .divide (M , - reg , out = K )
350
- np .exp (K , out = K )
367
+ u = np .zeros ((dim_a , 1 )) - math .log (dim_a )
368
+ v = np .zeros ((dim_b , 1 )) - math .log (dim_b )
351
369
352
- # print(np.min(K))
353
- tmp2 = np .empty (b .shape , dtype = M .dtype )
370
+ log_K = - M / reg
354
371
355
- Kp = (1 / a ).reshape (- 1 , 1 ) * K
372
+ log_Kp = - a .reshape (- 1 , 1 ) + log_K
373
+ log_K_T = log_K .T
356
374
cpt = 0
357
- err = 1
358
- while (err > stopThr and cpt < numItermax ):
375
+ log_err = 0.5 * log_threshold
376
+
377
+ while log_err > log_threshold and cpt < numItermax :
359
378
uprev = u
360
379
vprev = v
361
380
362
- KtransposeU = np .dot (K .T , u )
363
- v = np .divide (b , KtransposeU )
364
- u = 1. / np .dot (Kp , v )
381
+ log_matvec (log_K_T , u , v )
382
+ v *= - 1
383
+ v += b
384
+ log_matvec (log_Kp , v , u )
385
+ u *= - 1
365
386
366
- if (np .any (KtransposeU == 0 )
367
- or np .any (np .isnan (u )) or np .any (np .isnan (v ))
368
- or np .any (np .isinf (u )) or np .any (np .isinf (v ))):
387
+ if np .any (~ np .isfinite (u )) or np .any (~ np .isfinite (v )):
369
388
# we have reached the machine precision
370
389
# come back to previous solution and quit loop
371
390
print ('Warning: numerical errors at iteration' , cpt )
@@ -375,27 +394,32 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
375
394
if cpt % 10 == 0 :
376
395
# we can speed up the process by checking for the error only all
377
396
# the 10th iterations
378
- if n_hists :
379
- np .einsum ('ik,ij,jk->jk' , u , K , v , out = tmp2 )
380
- else :
381
- # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
382
- np .einsum ('i,ij,j->j' , u , K , v , out = tmp2 )
383
- err = np .linalg .norm (tmp2 - b ) # violation of marginal
397
+ temp2 = u + log_K + v .T
398
+ temp2 = logsumexp (temp2 , axis = 0 , keepdims = True ).T
399
+ # noinspection PyTypeChecker
400
+ log_err = 0.5 * np .sum (np .exp (2 * temp2 ) - np .exp (2 * b )) # violation of marginal
401
+ # would be more efficient with a check on stability of dual vectors
384
402
if log :
385
- log ['err' ].append (err )
403
+ log ['err' ].append (math . exp ( log_err ) )
386
404
387
405
if verbose :
388
406
if cpt % 200 == 0 :
389
407
print (
390
408
'{:5s}|{:12s}' .format ('It.' , 'Err' ) + '\n ' + '-' * 19 )
391
- print ('{:5d}|{:8e}|' .format (cpt , err ))
409
+ print ('{:5d}|{:8e}|' .format (cpt , np . exp ( log_err ) ))
392
410
cpt = cpt + 1
393
411
if log :
394
- log ['u' ] = u
395
- log ['v' ] = v
396
-
412
+ log ['u' ] = np .exp (u ) if not is_logweight else u
413
+ log ['v' ] = np .exp (v ) if not is_logweight else v
414
+
415
+ gamma = u + log_K + v .T
416
+ res = logsumexp (gamma , axis = (0 , 1 ), b = M )
417
+ if not is_logweight :
418
+ gamma = np .exp (gamma )
419
+ res = np .exp (res )
420
+ if log :
421
+ log ['cost' ] = res
397
422
if n_hists : # return only loss
398
- res = np .einsum ('ik,ij,jk,ij->k' , u , K , v , M )
399
423
if log :
400
424
return res , log
401
425
else :
@@ -404,9 +428,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
404
428
else : # return OT matrix
405
429
406
430
if log :
407
- return u . reshape (( - 1 , 1 )) * K * v . reshape (( 1 , - 1 ) ), log
431
+ return gamma . squeeze ( ), log
408
432
else :
409
- return u . reshape (( - 1 , 1 )) * K * v . reshape (( 1 , - 1 ) )
433
+ return gamma . squeeze ( )
410
434
411
435
412
436
def greenkhorn (a , b , M , reg , numItermax = 10000 , stopThr = 1e-9 , verbose = False ,
@@ -716,7 +740,7 @@ def get_Gamma(alpha, beta, u, v):
716
740
if np .abs (u ).max () > tau or np .abs (v ).max () > tau :
717
741
if n_hists :
718
742
alpha , beta = alpha + reg * \
719
- np .max (np .log (u ), 1 ), beta + reg * np .max (np .log (v ))
743
+ np .max (np .log (u ), 1 ), beta + reg * np .max (np .log (v ))
720
744
else :
721
745
alpha , beta = alpha + reg * np .log (u ), beta + reg * np .log (v )
722
746
if n_hists :
@@ -2182,11 +2206,11 @@ def projection(u, epsilon):
2182
2206
2183
2207
# box constraints in L-BFGS-B (see Proposition 1 in [26])
2184
2208
bounds_u = [(max (a_I_min / ((nt - nt_budget ) * epsilon + nt_budget * (b_J_max / (
2185
- ns * epsilon * kappa * K_min ))), epsilon / kappa ), a_I_max / (nt * epsilon * K_min ))] * ns_budget
2209
+ ns * epsilon * kappa * K_min ))), epsilon / kappa ), a_I_max / (nt * epsilon * K_min ))] * ns_budget
2186
2210
2187
2211
bounds_v = [(
2188
- max (b_J_min / ((ns - ns_budget ) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min ))),
2189
- epsilon * kappa ), b_J_max / (ns * epsilon * K_min ))] * nt_budget
2212
+ max (b_J_min / ((ns - ns_budget ) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min ))),
2213
+ epsilon * kappa ), b_J_max / (ns * epsilon * K_min ))] * nt_budget
2190
2214
2191
2215
# pre-calculated constants for the objective
2192
2216
vec_eps_IJc = epsilon * kappa * (K_IJc * np .ones (nt - nt_budget ).reshape ((1 , - 1 ))).sum (axis = 1 )
0 commit comments