@@ -435,18 +435,23 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
##############################################################################
- def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
-                           batch_beta):
+ def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
+                     batch_beta):
'''
Computes the partial gradient of F_\W_varepsilon

Compute the partial gradient of the dual problem:

..math:
    \forall i in batch_alpha,
-     grad_alpha_i = 1 * batch_size -
-     sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
-
+     grad_alpha_i = a_i * batch_size/len(beta) -
+     sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
+     * a_i * b_j
+
+     \forall j in batch_beta,
+     grad_beta_j = b_j * batch_size/len(alpha) -
+     sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
+     * a_i * b_j
where :
- M is the (ns,nt) metric cost matrix
- alpha, beta are dual variables in R^ixR^J
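Note: the formula above is the minibatch estimate of the gradient of the entropic
dual F(alpha, beta) = <a, alpha> + <b, beta> - reg * sum_ij exp((alpha_i + beta_j -
M_ij)/reg) * a_i * b_j. For reference, here is a minimal NumPy sketch of the dense
(non-minibatch) gradient that this estimate approximates; full_dual_grad is a
hypothetical helper used only for illustration, not part of the module:

    import numpy as np

    def full_dual_grad(M, reg, a, b, alpha, beta):
        # K_ij = exp((alpha_i + beta_j - M_ij) / reg) * a_i * b_j
        K = np.exp((alpha[:, None] + beta[None, :] - M) / reg) * a[:, None] * b[None, :]
        grad_alpha = a - K.sum(axis=1)  # the batched estimate rescales a_i by batch_size/len(beta)
        grad_beta = b - K.sum(axis=0)   # and restricts the sums to the sampled indices
        return grad_alpha, grad_beta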
@@ -478,7 +483,7 @@ def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
-------
grad : np.ndarray(ns,)
- partial grad F in alpha
+ partial grads of F in alpha and beta
Examples
--------
@@ -510,100 +515,20 @@ def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
arXiv preprint arxiv:1711.02283.
'''
- grad_alpha = np.zeros(batch_size)
- grad_alpha[:] = batch_size
- for j in batch_beta:
-     grad_alpha -= np.exp((alpha[batch_alpha] + beta[j] -
-                           M[batch_alpha, j]) / reg)
- return grad_alpha
-
-
- def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
-                          batch_beta):
-     '''
-     Computes the partial gradient of F_\W_varepsilon
-
-     Compute the partial gradient of the dual problem:
-
-     ..math:
-         \forall j in batch_beta,
-         grad_beta_j = 1 * batch_size -
-         sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
-
-     where :
-     - M is the (ns,nt) metric cost matrix
-     - alpha, beta are dual variables in R^ixR^J
-     - reg is the regularization term
-     - batch_alpha and batch_beta are list of index
-
-     The algorithm used for solving the dual problem is the SGD algorithm
-     as proposed in [19]_ [alg.1]
-
-     Parameters
-     ----------
-
-     M : np.ndarray(ns, nt),
-         cost matrix
-     reg : float number,
-         Regularization term > 0
-     alpha : np.ndarray(ns,)
-         dual variable
-     beta : np.ndarray(nt,)
-         dual variable
-     batch_size : int number
-         size of the batch
-     batch_alpha : np.ndarray(bs,)
-         batch of index of alpha
-     batch_beta : np.ndarray(bs,)
-         batch of index of beta
-
-     Returns
-     -------
-
-     grad : np.ndarray(ns,)
-         partial grad F in beta
-
-     Examples
-     --------
-
-     >>> n_source = 7
-     >>> n_target = 4
-     >>> reg = 1
-     >>> numItermax = 20000
-     >>> lr = 0.1
-     >>> batch_size = 3
-     >>> log = True
-     >>> a = ot.utils.unif(n_source)
-     >>> b = ot.utils.unif(n_target)
-     >>> rng = np.random.RandomState(0)
-     >>> X_source = rng.randn(n_source, 2)
-     >>> Y_target = rng.randn(n_target, 2)
-     >>> M = ot.dist(X_source, Y_target)
-     >>> sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg,
-                                                            batch_size,
-                                                            numItermax, lr, log)
-     >>> print(log['alpha'], log['beta'])
-     >>> print(sgd_dual_pi)
-
-     References
-     ----------
-
-     [Seguy et al., 2018] :
-         International Conference on Learning Representation (2018),
-         arXiv preprint arxiv:1711.02283.
+ G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
+                M[batch_alpha, :][:, batch_beta]) / reg) * a[batch_alpha, None] *
+        b[None, batch_beta])
+ grad_beta = np.zeros(np.shape(M)[1])
+ grad_alpha = np.zeros(np.shape(M)[0])
+ grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] +
+                          G.sum(0))
+ grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta) /
+                            np.shape(M)[1] + G.sum(1))
-     '''
-
-     grad_beta = np.zeros(batch_size)
-     grad_beta[:] = batch_size
-     for i in batch_alpha:
-         grad_beta -= np.exp((alpha[i] +
-                              beta[batch_beta] - M[i, batch_beta]) / reg)
-     return grad_beta
+ return grad_alpha, grad_beta
- def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
-                                 alternate=True):
+ def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
'''
Compute the sgd algorithm to solve the regularized discrete measures
optimal transport dual problem
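Note: the new batch_grad_dual body above vectorises the two per-index loops it
replaces. A small self-contained check (illustration only; the sizes, seed and
variable names are arbitrary, not part of the library) that the vectorised
expression matches an explicit loop over the sampled indices:

    import numpy as np

    rng = np.random.RandomState(0)
    ns, nt, bs, reg = 5, 4, 2, 1.0
    a, b = np.full(ns, 1. / ns), np.full(nt, 1. / nt)
    M = rng.rand(ns, nt)
    alpha, beta = rng.randn(ns), rng.randn(nt)
    batch_alpha = rng.choice(ns, bs, replace=False)
    batch_beta = rng.choice(nt, bs, replace=False)

    # vectorised minibatch gradient in alpha, mirroring batch_grad_dual
    G = -(np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
                  M[batch_alpha, :][:, batch_beta]) / reg) *
          a[batch_alpha, None] * b[None, batch_beta])
    grad_alpha = np.zeros(ns)
    grad_alpha[batch_alpha] = a[batch_alpha] * len(batch_beta) / nt + G.sum(1)

    # the same quantity with explicit loops
    grad_loop = np.zeros(ns)
    for i in batch_alpha:
        s = sum(np.exp((alpha[i] + beta[j] - M[i, j]) / reg) * a[i] * b[j]
                for j in batch_beta)
        grad_loop[i] = a[i] * len(batch_beta) / nt - s

    assert np.allclose(grad_alpha, grad_loop)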
@@ -628,6 +553,10 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
    cost matrix
reg : float number,
    Regularization term > 0
+ a : np.ndarray(ns,)
+     source measure
+ b : np.ndarray(nt,)
+     target measure
batch_size : int number
    size of the batch
numItermax : int number
@@ -677,35 +606,17 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
n_source = np.shape(M)[0]
n_target = np.shape(M)[1]
- cur_alpha = np.random.randn(n_source)
- cur_beta = np.random.randn(n_target)
- if alternate:
-     for cur_iter in range(numItermax):
-         k = np.sqrt(cur_iter + 1)
-         batch_alpha = np.random.choice(n_source, batch_size, replace=False)
-         batch_beta = np.random.choice(n_target, batch_size, replace=False)
-         grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
-                                              batch_size, batch_alpha,
-                                              batch_beta)
-         cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
-         grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
-                                            batch_size, batch_alpha,
-                                            batch_beta)
-         cur_beta[batch_beta] += (lr / k) * grad_F_beta
-
- else:
-     for cur_iter in range(numItermax):
-         k = np.sqrt(cur_iter + 1)
-         batch_alpha = np.random.choice(n_source, batch_size, replace=False)
-         batch_beta = np.random.choice(n_target, batch_size, replace=False)
-         grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
-                                              batch_size, batch_alpha,
-                                              batch_beta)
-         grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
-                                            batch_size, batch_alpha,
-                                            batch_beta)
-         cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
-         cur_beta[batch_beta] += (lr / k) * grad_F_beta
+ cur_alpha = np.zeros(n_source)
+ cur_beta = np.zeros(n_target)
+ for cur_iter in range(numItermax):
+     k = np.sqrt(cur_iter / 100 + 1)
+     batch_alpha = np.random.choice(n_source, batch_size, replace=False)
+     batch_beta = np.random.choice(n_target, batch_size, replace=False)
+     update_alpha, update_beta = batch_grad_dual(M, reg, a, b, cur_alpha,
+                                                 cur_beta, batch_size,
+                                                 batch_alpha, batch_beta)
+     cur_alpha += (lr / k) * update_alpha
+     cur_beta += (lr / k) * update_beta
return cur_alpha, cur_beta
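Note: the rewritten loop starts the dual variables at zero instead of a random
draw, and the step size now decays as lr / sqrt(cur_iter/100 + 1), i.e. much more
slowly than the previous lr / sqrt(cur_iter + 1). A short illustration of that
schedule (the values below are only an example):

    import numpy as np
    lr, numItermax = 0.1, 20000
    steps = lr / np.sqrt(np.arange(numItermax) / 100 + 1)  # decayed step sizes
    print(steps[0], steps[-1])  # 0.1 at the first iteration, ~0.007 at the last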
@@ -787,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
arXiv preprint arxiv:1711.02283.
'''
- opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, batch_size,
+ opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, a, b, batch_size,
                                                    numItermax, lr)
pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
      a[:, None] * b[None, :])
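Note: end-to-end usage of the updated solver mirrors the doctest kept in the
docstrings above (same sizes and seed; the printed marginal errors are only
indicative, since they depend on the random minibatches):

    import numpy as np
    import ot
    from ot import stochastic

    n_source, n_target = 7, 4
    reg, batch_size, numItermax, lr = 1, 3, 20000, 0.1
    a = ot.utils.unif(n_source)
    b = ot.utils.unif(n_target)
    rng = np.random.RandomState(0)
    X_source = rng.randn(n_source, 2)
    Y_target = rng.randn(n_target, 2)
    M = ot.dist(X_source, Y_target)
    sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg, batch_size,
                                                      numItermax, lr, log=True)
    # the recovered plan should approximately satisfy the marginal constraints
    print(np.abs(sgd_dual_pi.sum(1) - a).max(), np.abs(sgd_dual_pi.sum(0) - b).max())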