Skip to content

Commit 77b6890

Browse files
author
Kilian Fatras
committed
fixed bug in sgd dual
1 parent 208ff46 commit 77b6890

File tree

1 file changed

+38
-127
lines changed

1 file changed

+38
-127
lines changed

ot/stochastic.py

Lines changed: 38 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -435,18 +435,23 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
435435
##############################################################################
436436

437437

438-
def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
439-
batch_beta):
438+
def batch_grad_dual(M, reg, a, b, alpha, beta, batch_size, batch_alpha,
439+
batch_beta):
440440
'''
441441
Computes the partial gradient of F_{W_\varepsilon}
442442
443443
Compute the partial gradient of the dual problem:
444444
445445
..math:
446446
\forall i in batch_alpha,
447-
grad_alpha_i = 1 * batch_size -
448-
sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
449-
447+
grad_alpha_i = alpha_i * batch_size/len(beta) -
448+
sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
449+
* a_i * b_j
450+
451+
\forall j in batch_beta,
452+
grad_beta_j = beta_j * batch_size/len(alpha) -
453+
sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
454+
* a_i * b_j
450455
where :
451456
- M is the (ns,nt) metric cost matrix
452457
- alpha, beta are dual variables in R^ixR^J
@@ -478,7 +483,7 @@ def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
478483
-------
479484
480485
grad : np.ndarray(ns,)
481-
partial grad F in alpha
486+
partial grad F
482487
483488
Examples
484489
--------
@@ -510,100 +515,20 @@ def batch_grad_dual_alpha(M, reg, alpha, beta, batch_size, batch_alpha,
510515
arXiv preprint arxiv:1711.02283.
511516
'''
512517

513-
grad_alpha = np.zeros(batch_size)
514-
grad_alpha[:] = batch_size
515-
for j in batch_beta:
516-
grad_alpha -= np.exp((alpha[batch_alpha] + beta[j] -
517-
M[batch_alpha, j]) / reg)
518-
return grad_alpha
519-
520-
521-
def batch_grad_dual_beta(M, reg, alpha, beta, batch_size, batch_alpha,
522-
batch_beta):
523-
'''
524-
Computes the partial gradient of F_{W_\varepsilon}
525-
526-
Compute the partial gradient of the dual problem:
527-
528-
..math:
529-
\forall j in batch_beta,
530-
grad_beta_j = 1 * batch_size -
531-
sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
532-
533-
where :
534-
- M is the (ns,nt) metric cost matrix
535-
- alpha, beta are dual variables in R^ixR^J
536-
- reg is the regularization term
537-
- batch_alpha and batch_beta are list of index
538-
539-
The algorithm used for solving the dual problem is the SGD algorithm
540-
as proposed in [19]_ [alg.1]
541-
542-
Parameters
543-
----------
544-
545-
M : np.ndarray(ns, nt),
546-
cost matrix
547-
reg : float number,
548-
Regularization term > 0
549-
alpha : np.ndarray(ns,)
550-
dual variable
551-
beta : np.ndarray(nt,)
552-
dual variable
553-
batch_size : int number
554-
size of the batch
555-
batch_alpha : np.ndarray(bs,)
556-
batch of index of alpha
557-
batch_beta : np.ndarray(bs,)
558-
batch of index of beta
559-
560-
Returns
561-
-------
562-
563-
grad : np.ndarray(ns,)
564-
partial grad F in beta
565-
566-
Examples
567-
--------
568-
569-
>>> n_source = 7
570-
>>> n_target = 4
571-
>>> reg = 1
572-
>>> numItermax = 20000
573-
>>> lr = 0.1
574-
>>> batch_size = 3
575-
>>> log = True
576-
>>> a = ot.utils.unif(n_source)
577-
>>> b = ot.utils.unif(n_target)
578-
>>> rng = np.random.RandomState(0)
579-
>>> X_source = rng.randn(n_source, 2)
580-
>>> Y_target = rng.randn(n_target, 2)
581-
>>> M = ot.dist(X_source, Y_target)
582-
>>> sgd_dual_pi, log = stochastic.solve_dual_entropic(a, b, M, reg,
583-
batch_size,
584-
numItermax, lr, log)
585-
>>> print(log['alpha'], log['beta'])
586-
>>> print(sgd_dual_pi)
587-
588-
References
589-
----------
590-
591-
[Seguy et al., 2018] :
592-
International Conference on Learning Representation (2018),
593-
arXiv preprint arxiv:1711.02283.
518+
G = - (np.exp((alpha[batch_alpha, None] + beta[None, batch_beta] -
519+
M[batch_alpha, :][:, batch_beta]) / reg) * a[batch_alpha, None] *
520+
b[None, batch_beta])
521+
grad_beta = np.zeros(np.shape(M)[1])
522+
grad_alpha = np.zeros(np.shape(M)[0])
523+
grad_beta[batch_beta] = (b[batch_beta] * len(batch_alpha) / np.shape(M)[0] +
524+
G.sum(0))
525+
grad_alpha[batch_alpha] = (a[batch_alpha] * len(batch_beta) /
526+
np.shape(M)[1] + G.sum(1))
594527

595-
'''
596-
597-
grad_beta = np.zeros(batch_size)
598-
grad_beta[:] = batch_size
599-
for i in batch_alpha:
600-
grad_beta -= np.exp((alpha[i] +
601-
beta[batch_beta] - M[i, batch_beta]) / reg)
602-
return grad_beta
528+
return grad_alpha, grad_beta
603529

604530

605-
def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
606-
alternate=True):
531+
def sgd_entropic_regularization(M, reg, a, b, batch_size, numItermax, lr):
607532
'''
608533
Compute the sgd algorithm to solve the regularized discrete measures
609534
optimal transport dual problem
@@ -628,6 +553,10 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
628553
cost matrix
629554
reg : float number,
630555
Regularization term > 0
556+
alpha : np.ndarray(ns,)
557+
dual variable
558+
beta : np.ndarray(nt,)
559+
dual variable
631560
batch_size : int number
632561
size of the batch
633562
numItermax : int number
@@ -677,35 +606,17 @@ def sgd_entropic_regularization(M, reg, batch_size, numItermax, lr,
677606

678607
n_source = np.shape(M)[0]
679608
n_target = np.shape(M)[1]
680-
cur_alpha = np.random.randn(n_source)
681-
cur_beta = np.random.randn(n_target)
682-
if alternate:
683-
for cur_iter in range(numItermax):
684-
k = np.sqrt(cur_iter + 1)
685-
batch_alpha = np.random.choice(n_source, batch_size, replace=False)
686-
batch_beta = np.random.choice(n_target, batch_size, replace=False)
687-
grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
688-
batch_size, batch_alpha,
689-
batch_beta)
690-
cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
691-
grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
692-
batch_size, batch_alpha,
693-
batch_beta)
694-
cur_beta[batch_beta] += (lr / k) * grad_F_beta
695-
696-
else:
697-
for cur_iter in range(numItermax):
698-
k = np.sqrt(cur_iter + 1)
699-
batch_alpha = np.random.choice(n_source, batch_size, replace=False)
700-
batch_beta = np.random.choice(n_target, batch_size, replace=False)
701-
grad_F_alpha = batch_grad_dual_alpha(M, reg, cur_alpha, cur_beta,
702-
batch_size, batch_alpha,
703-
batch_beta)
704-
grad_F_beta = batch_grad_dual_beta(M, reg, cur_alpha, cur_beta,
705-
batch_size, batch_alpha,
706-
batch_beta)
707-
cur_alpha[batch_alpha] += (lr / k) * grad_F_alpha
708-
cur_beta[batch_beta] += (lr / k) * grad_F_beta
609+
cur_alpha = np.zeros(n_source)
610+
cur_beta = np.zeros(n_target)
611+
for cur_iter in range(numItermax):
612+
k = np.sqrt(cur_iter / 100 + 1)
613+
batch_alpha = np.random.choice(n_source, batch_size, replace=False)
614+
batch_beta = np.random.choice(n_target, batch_size, replace=False)
615+
update_alpha, update_beta = batch_grad_dual(M, reg, a, b, cur_alpha,
616+
cur_beta, batch_size,
617+
batch_alpha, batch_beta)
618+
cur_alpha += (lr / k) * update_alpha
619+
cur_beta += (lr / k) * update_beta
709620

710621
return cur_alpha, cur_beta
711622

@@ -787,7 +698,7 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
787698
arXiv preprint arxiv:1711.02283.
788699
'''
789700

790-
opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, batch_size,
701+
opt_alpha, opt_beta = sgd_entropic_regularization(M, reg, a, b, batch_size,
791702
numItermax, lr)
792703
pi = (np.exp((opt_alpha[:, None] + opt_beta[None, :] - M[:, :]) / reg) *
793704
a[:, None] * b[None, :])

0 commit comments

Comments
 (0)